From ddbfaeccba2be01fe0e54cacd29c058fdf5359e3 Mon Sep 17 00:00:00 2001 From: Solomon Choe <128758390+ywc88@users.noreply.github.com> Date: Tue, 22 Aug 2023 11:42:44 -0700 Subject: [PATCH 01/20] fix(go/adbc/driver/flightsql): Have GetTableSchema check for table name match instead of the first schema it receives (#980) Fixes #934. --- go/adbc/driver/flightsql/flightsql_adbc.go | 42 ++++++--- .../flightsql/flightsql_adbc_server_test.go | 92 +++++++++++++++++++ 2 files changed, 122 insertions(+), 12 deletions(-) diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go index 1ae99a6a55..e00310cfc3 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc.go +++ b/go/adbc/driver/flightsql/flightsql_adbc.go @@ -1231,24 +1231,42 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)") } - if rec.NumRows() == 0 { + numRows := rec.NumRows() + switch { + case numRows == 0: return nil, adbc.Error{ Code: adbc.StatusNotFound, } + case numRows > math.MaxInt32: + return nil, adbc.Error{ + Msg: "[Flight SQL] GetTableSchema cannot handle tables with number of rows > 2^31 - 1", + Code: adbc.StatusNotImplemented, + } } - // returned schema should be - // 0: catalog_name: utf8 - // 1: db_schema_name: utf8 - // 2: table_name: utf8 not null - // 3: table_type: utf8 not null - // 4: table_schema: bytes not null - schemaBytes := rec.Column(4).(*array.Binary).Value(0) - s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc) - if err != nil { - return nil, adbcFromFlightStatus(err, "GetTableSchema") + var s *arrow.Schema + for i := 0; i < int(numRows); i++ { + currentTableName := rec.Column(2).(*array.String).Value(i) + if currentTableName == tableName { + // returned schema should be + // 0: catalog_name: utf8 + // 1: db_schema_name: utf8 + // 2: table_name: utf8 not null + // 3: table_type: utf8 not null + // 4: table_schema: bytes not null + schemaBytes := rec.Column(4).(*array.Binary).Value(i) + s, err = flight.DeserializeSchema(schemaBytes, c.db.alloc) + if err != nil { + return nil, adbcFromFlightStatus(err, "GetTableSchema") + } + return s, nil + } + } + + return s, adbc.Error{ + Msg: "[Flight SQL] GetTableSchema could not find a table with a matching schema", + Code: adbc.StatusNotFound, } - return s, nil } // GetTableTypes returns a list of the table types in the database. diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index dd6171c4cd..d8af6a6579 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -35,6 +35,7 @@ import ( "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/flight" "github.com/apache/arrow/go/v13/arrow/flight/flightsql" + "github.com/apache/arrow/go/v13/arrow/flight/flightsql/schema_ref" "github.com/apache/arrow/go/v13/arrow/memory" "github.com/stretchr/testify/suite" "golang.org/x/exp/maps" @@ -107,6 +108,10 @@ func TestDataType(t *testing.T) { suite.Run(t, &DataTypeTests{}) } +func TestMultiTable(t *testing.T) { + suite.Run(t, &MultiTableTests{}) +} + // ---- AuthN Tests -------------------- type AuthnTestServer struct { @@ -627,3 +632,90 @@ func (suite *DataTypeTests) TestListInt() { func (suite *DataTypeTests) TestMapIntInt() { suite.DoTestCase("map[int]int", SchemaMapIntInt) } + +// ---- Multi Table Tests -------------------- + +type MultiTableTestServer struct { + flightsql.BaseServer +} + +func (server *MultiTableTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + query := cmd.GetQuery() + tkt, err := flightsql.CreateStatementQueryTicket([]byte(query)) + if err != nil { + return nil, err + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (server *MultiTableTestServer) GetFlightInfoTables(ctx context.Context, cmd flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + schema := schema_ref.Tables + if cmd.GetIncludeSchema() { + schema = schema_ref.TablesWithIncludedSchema + } + server.Alloc = memory.NewCheckedAllocator(memory.DefaultAllocator) + info := &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{ + {Ticket: &flight.Ticket{Ticket: desc.Cmd}}, + }, + FlightDescriptor: desc, + Schema: flight.SerializeSchema(schema, server.Alloc), + TotalRecords: -1, + TotalBytes: -1, + } + + return info, nil +} + +func (server *MultiTableTestServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { + bldr := array.NewRecordBuilder(server.Alloc, adbc.GetTableSchemaSchema) + + bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"", ""}, nil) + bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"", ""}, nil) + bldr.Field(2).(*array.StringBuilder).AppendValues([]string{"tbl1", "tbl2"}, nil) + bldr.Field(3).(*array.StringBuilder).AppendValues([]string{"", ""}, nil) + + sc1 := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + sc2 := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + buf1 := flight.SerializeSchema(sc1, server.Alloc) + buf2 := flight.SerializeSchema(sc2, server.Alloc) + + bldr.Field(4).(*array.BinaryBuilder).AppendValues([][]byte{buf1, buf2}, nil) + defer bldr.Release() + + rec := bldr.NewRecord() + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + }() + return adbc.GetTableSchemaSchema, ch, nil +} + +type MultiTableTests struct { + ServerBasedTests +} + +func (suite *MultiTableTests) SetupSuite() { + suite.DoSetupSuite(&MultiTableTestServer{}, nil, map[string]string{}) +} + +// Regression test for https://github.com/apache/arrow-adbc/issues/934 +func (suite *MultiTableTests) TestGetTableSchema() { + actualSchema, err := suite.cnxn.GetTableSchema(context.Background(), nil, nil, "tbl2") + suite.NoError(err) + + expectedSchema := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + suite.Equal(expectedSchema, actualSchema) +} From d901b87688ff2f34f169584daccf8a44300d8c92 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 23 Aug 2023 09:25:18 -0400 Subject: [PATCH 02/20] docs: pin furo version (#988) Fixes #987. --- ci/conda_env_docs.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/conda_env_docs.txt b/ci/conda_env_docs.txt index 008b61d17a..ad1c9939f8 100644 --- a/ci/conda_env_docs.txt +++ b/ci/conda_env_docs.txt @@ -17,7 +17,8 @@ breathe doxygen -furo +# XXX(https://github.com/apache/arrow-adbc/issues/987) +furo=2023.07.26 make numpydoc pytest From 19a30aea784dfdce1fed506604f0272a222a3294 Mon Sep 17 00:00:00 2001 From: Solomon Choe <128758390+ywc88@users.noreply.github.com> Date: Wed, 23 Aug 2023 11:25:57 -0700 Subject: [PATCH 03/20] feat(python/adbc_driver_manager): add fetch_record_batch (#989) Fixes #968 --------- Co-authored-by: David Li --- .../adbc_driver_manager/dbapi.py | 26 ++++++++++++++++--- .../adbc_driver_manager/tests/test_dbapi.py | 20 ++++++++++++++ 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 31e4392ae5..60bc2d1bce 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -926,6 +926,24 @@ def fetch_df(self) -> "pandas.DataFrame": ) return self._results.fetch_df() + def fetch_record_batch(self) -> pyarrow.RecordBatchReader: + """ + Fetch the result as a PyArrow RecordBatchReader. + + This implements a similar API as DuckDB: + https://duckdb.org/docs/guides/python/export_arrow.html#export-as-a-recordbatchreader + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + if self._results is None: + raise ProgrammingError( + "Cannot fetch_record_batch() before execute()", + status_code=_lib.AdbcStatusCode.INVALID_STATE, + ) + return self._results._reader + # ---------------------------------------------------------- # Utilities @@ -973,7 +991,7 @@ def fetchone(self) -> Optional[tuple]: self.rownumber += 1 return row - def fetchmany(self, size: int): + def fetchmany(self, size: int) -> List[tuple]: rows = [] for _ in range(size): row = self.fetchone() @@ -982,7 +1000,7 @@ def fetchmany(self, size: int): rows.append(row) return rows - def fetchall(self): + def fetchall(self) -> List[tuple]: rows = [] while True: row = self.fetchone() @@ -991,10 +1009,10 @@ def fetchall(self): rows.append(row) return rows - def fetch_arrow_table(self): + def fetch_arrow_table(self) -> pyarrow.Table: return self._reader.read_all() - def fetch_df(self): + def fetch_df(self) -> "pandas.DataFrame": return self._reader.read_pandas() diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py index 1eba12fda3..a29a661a23 100644 --- a/python/adbc_driver_manager/tests/test_dbapi.py +++ b/python/adbc_driver_manager/tests/test_dbapi.py @@ -294,6 +294,26 @@ def test_executemany(sqlite): assert next(cur) == (5, 6) +@pytest.mark.sqlite +def test_fetch_record_batch(sqlite): + dataset = [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + ] + with sqlite.cursor() as cur: + cur.execute("CREATE TABLE foo (a, b)") + cur.executemany( + "INSERT INTO foo VALUES (?, ?)", + dataset, + ) + cur.execute("SELECT * FROM foo") + rbr = cur.fetch_record_batch() + assert rbr.read_pandas().values.tolist() == dataset + + @pytest.mark.sqlite def test_fetch_empty(sqlite): with sqlite.cursor() as cur: From 60997c1b0bfc25883851a5870b4a0ac30e728f3a Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 23 Aug 2023 14:57:04 -0400 Subject: [PATCH 04/20] chore: bump to arrow-go v13 (#990) Fixes #927. --- go/adbc/driver/flightsql/flightsql_adbc_test.go | 2 +- go/adbc/go.mod | 2 +- go/adbc/go.sum | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go index 2f96093408..ee7f92802e 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go @@ -252,7 +252,7 @@ func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} { case adbc.InfoVendorVersion: return "sqlite 3" case adbc.InfoVendorArrowVersion: - return "13.0.0-SNAPSHOT" + return "13.0.0" } return nil diff --git a/go/adbc/go.mod b/go/adbc/go.mod index 999a8d47f1..b1bb5234c0 100644 --- a/go/adbc/go.mod +++ b/go/adbc/go.mod @@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc go 1.18 require ( - github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355 + github.com/apache/arrow/go/v13 v13.0.0 github.com/bluele/gcache v0.0.2 github.com/google/uuid v1.3.0 github.com/snowflakedb/gosnowflake v1.6.22 diff --git a/go/adbc/go.sum b/go/adbc/go.sum index 7e47f67f5c..83f49cc229 100644 --- a/go/adbc/go.sum +++ b/go/adbc/go.sum @@ -17,8 +17,8 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/ github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/apache/arrow/go/v12 v12.0.1 h1:JsR2+hzYYjgSUkBSaahpqCetqZMr76djX80fF/DiJbg= github.com/apache/arrow/go/v12 v12.0.1/go.mod h1:weuTY7JvTG/HDPtMQxEUp7pU73vkLWMLpY67QwZ/WWw= -github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355 h1:QuXqLb2HzL5EjY99fFp+iG9NagAruvQIbU/2++x+2VY= -github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc= +github.com/apache/arrow/go/v13 v13.0.0 h1:kELrvDQuKZo8csdWYqBQfyi431x6Zs/YJTEgUuSVcWk= +github.com/apache/arrow/go/v13 v13.0.0/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc= github.com/apache/thrift v0.17.0 h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo= github.com/apache/thrift v0.17.0/go.mod h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q= github.com/aws/aws-sdk-go-v2 v1.19.0 h1:klAT+y3pGFBU/qVf1uzwttpBbiuozJYWzNLHioyDJ+k= From db0b9c111ed8721db512ff2ac493b3f050610284 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 27 Aug 2023 10:07:44 +0900 Subject: [PATCH 05/20] docs: add APT/Yum repositories to installation pages (#992) Fixes #991. --- docs/source/cpp/driver_manager.rst | 64 ++++++++++++++++++++++++++++- docs/source/driver/installation.rst | 64 ++++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 3 deletions(-) diff --git a/docs/source/cpp/driver_manager.rst b/docs/source/cpp/driver_manager.rst index 120e5dd5f0..d8db791d1f 100644 --- a/docs/source/cpp/driver_manager.rst +++ b/docs/source/cpp/driver_manager.rst @@ -27,7 +27,69 @@ specific driver. Installation ============ -TODO +Install the appropriate driver package. You can use conda-forge_, ``apt`` or ``dnf``. + +conda-forge: + +- ``mamba install adbc-driver-manager`` + +You can use ``apt`` on the following platforms: + +- Debian GNU/Linux bookworm +- Ubuntu 22.04 + +Prepare the Apache Arrow APT repository: + +.. code-block:: bash + + sudo apt update + sudo apt install -y -V ca-certificates lsb-release wget + sudo wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + rm ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + sudo apt update + +Install: + +- ``sudo apt install libadbc-driver-manager-dev`` + +You can use ``dnf`` on the following platforms: + +- AlmaLinux 8 +- Oracle Linux 8 +- Red Hat Enterprise Linux 8 +- AlmaLinux 9 +- Oracle Linux 9 +- Red Hat Enterprise Linux 9 + +Prepare the Apache Arrow Yum repository: + +.. code-block:: bash + + sudo dnf install -y epel-release || sudo dnf install -y oracle-epel-release-el$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1) || sudo dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1).noarch.rpm + sudo dnf install -y https://apache.jfrog.io/artifactory/arrow/almalinux/$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)/apache-arrow-release-latest.rpm + sudo dnf config-manager --set-enabled epel || : + sudo dnf config-manager --set-enabled powertools || : + sudo dnf config-manager --set-enabled crb || : + sudo dnf config-manager --set-enabled ol$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)_codeready_builder || : + sudo dnf config-manager --set-enabled codeready-builder-for-rhel-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)-rhui-rpms || : + sudo subscription-manager repos --enable codeready-builder-for-rhel-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)-$(arch)-rpms || : + +Install: + +- ``sudo dnf install adbc-driver-manager-devel`` + +Then they can be used via CMake, e.g.: + +.. code-block:: cmake + + find_package(AdbcDriverPostgreSQL) + + # ... + + target_link_libraries(myapp PRIVATE AdbcDriverPostgreSQL::adbc_driver_postgresql_shared) + +.. _conda-forge: https://conda-forge.org/ Usage ===== diff --git a/docs/source/driver/installation.rst b/docs/source/driver/installation.rst index ffd3a8c9a5..7540747774 100644 --- a/docs/source/driver/installation.rst +++ b/docs/source/driver/installation.rst @@ -26,12 +26,66 @@ Installation C/C++ ===== -Install the appropriate driver package. These are currently only available from conda-forge_: +Install the appropriate driver package. You can use conda-forge_, ``apt`` or ``dnf``. + +conda-forge: - ``mamba install libadbc-driver-flightsql`` - ``mamba install libadbc-driver-postgresql`` - ``mamba install libadbc-driver-sqlite`` +You can use ``apt`` on the following platforms: + +- Debian GNU/Linux bookworm +- Ubuntu 22.04 + +Prepare the Apache Arrow APT repository: + +.. code-block:: bash + + sudo apt update + sudo apt install -y -V ca-certificates lsb-release wget + sudo wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + rm ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + sudo apt update + +Install: + +- ``sudo apt install libadbc-driver-flightsql-dev`` +- ``sudo apt install libadbc-driver-postgresql-dev`` +- ``sudo apt install libadbc-driver-sqlite-dev`` +- ``sudo apt install libadbc-driver-snowflake-dev`` + +You can use ``dnf`` on the following platforms: + +- AlmaLinux 8 +- Oracle Linux 8 +- Red Hat Enterprise Linux 8 +- AlmaLinux 9 +- Oracle Linux 9 +- Red Hat Enterprise Linux 9 + +Prepare the Apache Arrow Yum repository: + +.. code-block:: bash + + sudo dnf install -y epel-release || sudo dnf install -y oracle-epel-release-el$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1) || sudo dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1).noarch.rpm + sudo dnf install -y https://apache.jfrog.io/artifactory/arrow/almalinux/$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)/apache-arrow-release-latest.rpm + sudo dnf config-manager --set-enabled epel || : + sudo dnf config-manager --set-enabled powertools || : + sudo dnf config-manager --set-enabled crb || : + sudo dnf config-manager --set-enabled ol$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)_codeready_builder || : + sudo dnf config-manager --set-enabled codeready-builder-for-rhel-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)-rhui-rpms || : + sudo subscription-manager repos --enable codeready-builder-for-rhel-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)-$(arch)-rpms || : + +Install: + +- ``sudo dnf install adbc-driver-flightsql-devel`` +- ``sudo dnf install adbc-driver-postgresql-devel`` +- ``sudo dnf install adbc-driver-sqlite-devel`` +- ``sudo dnf install adbc-driver-snowflake-devel`` + Then they can be used via CMake, e.g.: .. code-block:: cmake @@ -80,7 +134,7 @@ From conda-forge_: - ``mamba install adbc-driver-sqlite`` R -====== += Install the appropriate driver package from GitHub: @@ -94,3 +148,9 @@ Install the appropriate driver package from GitHub: Installation of stable releases from CRAN is anticipated following the release of ADBC Libraries 0.6.0. + +Ruby +==== + +Install the appropriate driver package for C/C++. You can use it from +Ruby. From 370aac6622df5b30c8fa8fa6c0c7d3a61e2520d7 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Aug 2023 08:58:19 -0400 Subject: [PATCH 06/20] chore: bump versions to 0.7.0 (#993) --- CHANGELOG.md | 50 +++++++++++++++++++ c/cmake_modules/AdbcVersion.cmake | 2 +- ci/conda/meta.yaml | 2 +- ci/linux-packages/debian/control | 22 ++++---- ...ll => libadbc-driver-flightsql007.install} | 0 ...tall => libadbc-driver-manager007.install} | 0 ...l => libadbc-driver-postgresql007.install} | 0 ...ll => libadbc-driver-snowflake007.install} | 0 ...stall => libadbc-driver-sqlite007.install} | 0 docs/source/conf.py | 2 +- glib/meson.build | 2 +- java/core/pom.xml | 2 +- java/driver-manager/pom.xml | 2 +- java/driver/flight-sql-validation/pom.xml | 2 +- java/driver/flight-sql/pom.xml | 2 +- java/driver/jdbc-validation-derby/pom.xml | 2 +- .../jdbc-validation-mssqlserver/pom.xml | 2 +- .../driver/jdbc-validation-postgresql/pom.xml | 2 +- java/driver/jdbc/pom.xml | 2 +- java/driver/validation/pom.xml | 2 +- java/pom.xml | 4 +- java/sql/pom.xml | 2 +- r/adbcdrivermanager/DESCRIPTION | 2 +- r/adbcflightsql/DESCRIPTION | 2 +- r/adbcpostgresql/DESCRIPTION | 2 +- r/adbcsnowflake/DESCRIPTION | 2 +- r/adbcsqlite/DESCRIPTION | 2 +- ruby/lib/adbc/version.rb | 2 +- rust/Cargo.toml | 2 +- 29 files changed, 84 insertions(+), 34 deletions(-) rename ci/linux-packages/debian/{libadbc-driver-flightsql006.install => libadbc-driver-flightsql007.install} (100%) rename ci/linux-packages/debian/{libadbc-driver-manager006.install => libadbc-driver-manager007.install} (100%) rename ci/linux-packages/debian/{libadbc-driver-postgresql006.install => libadbc-driver-postgresql007.install} (100%) rename ci/linux-packages/debian/{libadbc-driver-snowflake006.install => libadbc-driver-snowflake007.install} (100%) rename ci/linux-packages/debian/{libadbc-driver-sqlite006.install => libadbc-driver-sqlite007.install} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0acdedfa81..190da9867b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -237,3 +237,53 @@ ### Perf - **go/adbc/driver/flightsql**: filter by schema in getObjectsTables (#726) + +## ADBC Libraries 0.6.0 (2023-08-23) + +### Feat + +- **python/adbc_driver_manager**: add fetch_record_batch (#989) +- **c/driver**: Date32 support (#948) +- **c/driver/postgresql**: Interval support (#908) +- **go/adbc/driver/flightsql**: add context to gRPC errors (#921) +- **c/driver/sqlite**: SQLite timestamp write support (#897) +- **c/driver/postgresql**: Handle NUMERIC type by converting to string (#883) +- **python/adbc_driver_postgresql**: add PostgreSQL options enum (#886) +- **c/driver/postgresql**: TimestampTz write (#868) +- **c/driver/postgresql**: Implement streaming/chunked output (#870) +- **c/driver/postgresql**: Timestamp write support (#861) +- **c/driver_manager,go/adbc,python**: trim down error messages (#866) +- **c/driver/postgresql**: Int8 support (#858) +- **c/driver/postgresql**: Better type error messages (#860) + +### Fix + +- **go/adbc/driver/flightsql**: Have GetTableSchema check for table name match instead of the first schema it receives (#980) +- **r**: Ensure that info_codes are coerced to integer (#986) +- **go/adbc/sqldriver**: fix handling of decimal types (#970) +- **c/driver/postgresql**: Fix segfault associated with uninitialized copy_reader_ (#964) +- **c/driver/sqlite**: add table types by default from arrow types (#955) +- **csharp**: include GetTableTypes and GetTableSchema call for .NET 4.7.2 (#950) +- **csharp**: include GetInfo and GetObjects call for .NET 4.7.2 (#945) +- **c/driver/sqlite**: Wrap bulk ingests in a single begin/commit txn (#910) +- **csharp**: fix C api to work under .NET 4.7.2 (#931) +- **python/adbc_driver_snowflake**: allow connecting without URI (#923) +- **go/adbc/pkg**: export Adbc* symbols on Windows (#916) +- **go/adbc/driver/snowflake**: handle non-arrow result sets (#909) +- **c/driver/sqlite**: fix escaping of sqlite TABLE CREATE columns (#906) +- **go/adbc/pkg**: follow CGO rules properly (#902) +- **go/adbc/driver/snowflake**: Fix integration tests by fixing timestamp handling (#889) +- **go/adbc/driver/snowflake**: fix failing integration tests (#888) +- **c/validation**: Fix ASAN-detected leak (#879) +- **go/adbc**: fix crash on map type (#854) +- **go/adbc/driver/snowflake**: handle result sets without Arrow data (#864) + +### Perf + +- **go/adbc/driver/snowflake**: Implement concurrency limit (#974) + +### Refactor + +- **c**: Vendor portable-snippets for overflow checks (#951) +- **c/driver/postgresql**: Use ArrowArrayViewGetIntervalUnsafe from nanoarrow (#957) +- **c/driver/postgresql**: Simplify current database querying (#880) diff --git a/c/cmake_modules/AdbcVersion.cmake b/c/cmake_modules/AdbcVersion.cmake index 8a27457cd9..a39565cde8 100644 --- a/c/cmake_modules/AdbcVersion.cmake +++ b/c/cmake_modules/AdbcVersion.cmake @@ -21,7 +21,7 @@ # ------------------------------------------------------------ # Version definitions -set(ADBC_VERSION "0.6.0-SNAPSHOT") +set(ADBC_VERSION "0.7.0-SNAPSHOT") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ADBC_BASE_VERSION "${ADBC_VERSION}") string(REPLACE "." ";" _adbc_version_list "${ADBC_BASE_VERSION}") list(GET _adbc_version_list 0 ADBC_VERSION_MAJOR) diff --git a/ci/conda/meta.yaml b/ci/conda/meta.yaml index dfa02ee0cb..66d66cbf50 100644 --- a/ci/conda/meta.yaml +++ b/ci/conda/meta.yaml @@ -18,7 +18,7 @@ package: name: arrow-adbc-split # TODO: this needs to get bumped by the release process - version: 0.6.0 + version: 0.7.0 source: path: ../../ diff --git a/ci/linux-packages/debian/control b/ci/linux-packages/debian/control index dbf6b6e3d1..2b4a06a55f 100644 --- a/ci/linux-packages/debian/control +++ b/ci/linux-packages/debian/control @@ -33,7 +33,7 @@ Build-Depends: Standards-Version: 4.5.0 Homepage: https://arrow.apache.org/adbc/ -Package: libadbc-driver-manager006 +Package: libadbc-driver-manager007 Section: libs Architecture: any Multi-Arch: same @@ -51,12 +51,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-manager006 (= ${binary:Version}) + libadbc-driver-manager007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) driver manager . This package provides C++ header files. -Package: libadbc-driver-postgresql006 +Package: libadbc-driver-postgresql007 Section: libs Architecture: any Multi-Arch: same @@ -74,12 +74,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-postgresql006 (= ${binary:Version}) + libadbc-driver-postgresql007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) PostgreSQL driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-sqlite006 +Package: libadbc-driver-sqlite007 Section: libs Architecture: any Multi-Arch: same @@ -97,12 +97,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-sqlite006 (= ${binary:Version}) + libadbc-driver-sqlite007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) SQLite driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-flightsql006 +Package: libadbc-driver-flightsql007 Section: libs Architecture: any Multi-Arch: same @@ -120,12 +120,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-flightsql006 (= ${binary:Version}) + libadbc-driver-flightsql007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) Flight SQL driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-snowflake006 +Package: libadbc-driver-snowflake007 Section: libs Architecture: any Multi-Arch: same @@ -143,7 +143,7 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-snowflake006 (= ${binary:Version}) + libadbc-driver-snowflake007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) Snowflake driver . This package provides CMake package, pkg-config package and so on. @@ -157,7 +157,7 @@ Pre-Depends: ${misc:Pre-Depends} Depends: ${misc:Depends}, ${shlibs:Depends}, - libadbc-driver-manager006 (= ${binary:Version}) + libadbc-driver-manager007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) driver manager . This package provides GLib based library files. diff --git a/ci/linux-packages/debian/libadbc-driver-flightsql006.install b/ci/linux-packages/debian/libadbc-driver-flightsql007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-flightsql006.install rename to ci/linux-packages/debian/libadbc-driver-flightsql007.install diff --git a/ci/linux-packages/debian/libadbc-driver-manager006.install b/ci/linux-packages/debian/libadbc-driver-manager007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-manager006.install rename to ci/linux-packages/debian/libadbc-driver-manager007.install diff --git a/ci/linux-packages/debian/libadbc-driver-postgresql006.install b/ci/linux-packages/debian/libadbc-driver-postgresql007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-postgresql006.install rename to ci/linux-packages/debian/libadbc-driver-postgresql007.install diff --git a/ci/linux-packages/debian/libadbc-driver-snowflake006.install b/ci/linux-packages/debian/libadbc-driver-snowflake007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-snowflake006.install rename to ci/linux-packages/debian/libadbc-driver-snowflake007.install diff --git a/ci/linux-packages/debian/libadbc-driver-sqlite006.install b/ci/linux-packages/debian/libadbc-driver-sqlite007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-sqlite006.install rename to ci/linux-packages/debian/libadbc-driver-sqlite007.install diff --git a/docs/source/conf.py b/docs/source/conf.py index b1c2dfa356..5c370f225b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ project = "ADBC" copyright = "2022, Apache Arrow Developers" author = "the Apache Arrow Developers" -release = "0.6.0 (dev)" +release = "0.7.0 (dev)" # Needed to generate version switcher version = release diff --git a/glib/meson.build b/glib/meson.build index 55d9a13a3c..7516c433dc 100644 --- a/glib/meson.build +++ b/glib/meson.build @@ -23,7 +23,7 @@ project('adbc-glib', 'c_std=c99', ], license: 'Apache-2.0', - version: '0.6.0-SNAPSHOT') + version: '0.7.0-SNAPSHOT') version_numbers = meson.project_version().split('-')[0].split('.') version_major = version_numbers[0].to_int() diff --git a/java/core/pom.xml b/java/core/pom.xml index 837119d527..742651be12 100644 --- a/java/core/pom.xml +++ b/java/core/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT adbc-core diff --git a/java/driver-manager/pom.xml b/java/driver-manager/pom.xml index 12d7a3f3c0..bfaba9ba7d 100644 --- a/java/driver-manager/pom.xml +++ b/java/driver-manager/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT adbc-driver-manager diff --git a/java/driver/flight-sql-validation/pom.xml b/java/driver/flight-sql-validation/pom.xml index 987108c2f3..57e685f2a1 100644 --- a/java/driver/flight-sql-validation/pom.xml +++ b/java/driver/flight-sql-validation/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml index 432967963b..db7aef9900 100644 --- a/java/driver/flight-sql/pom.xml +++ b/java/driver/flight-sql/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/jdbc-validation-derby/pom.xml b/java/driver/jdbc-validation-derby/pom.xml index f273218c65..a97f8c21fa 100644 --- a/java/driver/jdbc-validation-derby/pom.xml +++ b/java/driver/jdbc-validation-derby/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/jdbc-validation-mssqlserver/pom.xml b/java/driver/jdbc-validation-mssqlserver/pom.xml index 5dd6b54acd..e0ba64cc4a 100644 --- a/java/driver/jdbc-validation-mssqlserver/pom.xml +++ b/java/driver/jdbc-validation-mssqlserver/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/jdbc-validation-postgresql/pom.xml b/java/driver/jdbc-validation-postgresql/pom.xml index 295ca5d4f0..1e0e5407c9 100644 --- a/java/driver/jdbc-validation-postgresql/pom.xml +++ b/java/driver/jdbc-validation-postgresql/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/jdbc/pom.xml b/java/driver/jdbc/pom.xml index 0c78818b1b..73ebad38e4 100644 --- a/java/driver/jdbc/pom.xml +++ b/java/driver/jdbc/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/validation/pom.xml b/java/driver/validation/pom.xml index 20c2169762..c70235002f 100644 --- a/java/driver/validation/pom.xml +++ b/java/driver/validation/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/pom.xml b/java/pom.xml index 9b2b1d7d47..fd244227fb 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -20,7 +20,7 @@ org.apache.arrow.adbc arrow-adbc-java-root - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT pom Apache Arrow ADBC Java Root POM @@ -29,7 +29,7 @@ 12.0.0 - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT diff --git a/java/sql/pom.xml b/java/sql/pom.xml index 2b13654ce9..4703f1df76 100644 --- a/java/sql/pom.xml +++ b/java/sql/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT adbc-sql diff --git a/r/adbcdrivermanager/DESCRIPTION b/r/adbcdrivermanager/DESCRIPTION index 3bb8509d3f..816aa909d8 100644 --- a/r/adbcdrivermanager/DESCRIPTION +++ b/r/adbcdrivermanager/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcdrivermanager Title: 'Arrow' Database Connectivity ('ADBC') Driver Manager -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcflightsql/DESCRIPTION b/r/adbcflightsql/DESCRIPTION index e4bbcda297..c041186a52 100644 --- a/r/adbcflightsql/DESCRIPTION +++ b/r/adbcflightsql/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcflightsql Title: 'Arrow' Database Connectivity ('ADBC') 'FlightSQL' Driver -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcpostgresql/DESCRIPTION b/r/adbcpostgresql/DESCRIPTION index 7a729f1b7b..86a1812cb2 100644 --- a/r/adbcpostgresql/DESCRIPTION +++ b/r/adbcpostgresql/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcpostgresql Title: 'Arrow' Database Connectivity ('ADBC') 'PostgreSQL' Driver -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcsnowflake/DESCRIPTION b/r/adbcsnowflake/DESCRIPTION index c21694d510..5351956a9f 100644 --- a/r/adbcsnowflake/DESCRIPTION +++ b/r/adbcsnowflake/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcsnowflake Title: Arrow Database Connectivity ('ADBC') 'Snowflake' Driver -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcsqlite/DESCRIPTION b/r/adbcsqlite/DESCRIPTION index 53d589a3e5..83bbf8f97a 100644 --- a/r/adbcsqlite/DESCRIPTION +++ b/r/adbcsqlite/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcsqlite Title: 'Arrow' Database Connectivity ('ADBC') 'SQLite' Driver -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/ruby/lib/adbc/version.rb b/ruby/lib/adbc/version.rb index 3843fb710a..e853463bfe 100644 --- a/ruby/lib/adbc/version.rb +++ b/ruby/lib/adbc/version.rb @@ -16,7 +16,7 @@ # under the License. module ADBC - VERSION = "0.6.0-SNAPSHOT" + VERSION = "0.7.0-SNAPSHOT" module Version MAJOR, MINOR, MICRO, TAG = VERSION.split(".").collect(&:to_i) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index dc7a37dde5..4daec0cc35 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "arrow-adbc" -version = "0.6.0-SNAPSHOT" +version = "0.7.0-SNAPSHOT" edition = "2021" rust-version = "1.62" description = "Rust implementation of Arrow Database Connectivity (ADBC)" From e6319052af945fe910ee0265d02fc9048caf72ed Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Aug 2023 09:45:27 -0400 Subject: [PATCH 07/20] chore(dev/release): fix typos in release scripts (#994) --- dev/release/02-sign.sh | 4 ++-- dev/release/06-binary-verify.sh | 2 +- docs/source/development/releasing.rst | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dev/release/02-sign.sh b/dev/release/02-sign.sh index 6267ce8407..2eb8dce5c0 100755 --- a/dev/release/02-sign.sh +++ b/dev/release/02-sign.sh @@ -23,8 +23,8 @@ main() { local -r source_top_dir="$( cd "${source_dir}/../../" && pwd )" pushd "${source_top_dir}" - if [ "$#" -ne 2 ]; then - echo "Usage: $0 " + if [ "$#" -ne 3 ]; then + echo "Usage: $0 " exit 1 fi diff --git a/dev/release/06-binary-verify.sh b/dev/release/06-binary-verify.sh index 8f1a118249..2a9aad46e4 100755 --- a/dev/release/06-binary-verify.sh +++ b/dev/release/06-binary-verify.sh @@ -89,7 +89,7 @@ The vote will be open for at least 72 hours. [ ] +0 [ ] -1 Do not release this as Apache Arrow ADBC ${version} because... -Note: to verify APT/YUM packages on macOS/AArch64, you must \`export DOCKER_DEFAULT_ARCHITECTURE=linux/amd64\`. (Or skip this step by \`export TEST_APT=0 TEST_YUM=0\`.) +Note: to verify APT/YUM packages on macOS/AArch64, you must \`export DOCKER_DEFAULT_PLATFORM=linux/amd64\`. (Or skip this step by \`export TEST_APT=0 TEST_YUM=0\`.) [1]: https://github.com/apache/arrow-adbc/issues?q=is%3Aissue+milestone%3A%22ADBC+Libraries+${version}%22+is%3Aclosed [2]: https://github.com/apache/arrow-adbc/commit/${commit} diff --git a/docs/source/development/releasing.rst b/docs/source/development/releasing.rst index cef9fab8e4..6408b10bb4 100644 --- a/docs/source/development/releasing.rst +++ b/docs/source/development/releasing.rst @@ -275,7 +275,7 @@ Be sure to go through on the following checklist: # dev/release/post-01-upload.sh 0.1.0 0 dev/release/post-01-upload.sh - git push --tag apache apache-arrow-adbc- + git push apache apache-arrow-adbc- .. dropdown:: Create the final GitHub release :class-title: sd-fs-5 @@ -303,8 +303,8 @@ Be sure to go through on the following checklist: .. code-block:: Bash - # dev/release/post-03-python.sh 10.0.0 - dev/release/post-03-python.sh + # dev/release/post-03-python.sh 0.1.0 0 + dev/release/post-03-python.sh .. dropdown:: Publish Maven packages :class-title: sd-fs-5 From ba7032da87bd138e4b6310ad5004eb04aa85a481 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Aug 2023 10:56:23 -0400 Subject: [PATCH 08/20] feat: ADBC API revision 1.1.0 (#971) Fixes #55. Fixes #317. Fixes #318. Fixes #319. Fixes #442. Fixes #458. Fixes #459. Fixes #541. Fixes #620. Fixes #685. Fixes #736. Fixes #755. Fixes #939. Fixes #940. Fixes #942. Fixes #962. --------- Co-authored-by: Matt Topol Co-authored-by: Sutou Kouhei Co-authored-by: Will Jones --- .env | 2 +- .gitattributes | 5 + .github/workflows/native-unix.yml | 1 + .pre-commit-config.yaml | 4 +- adbc.h | 1108 ++++++++++++++- c/driver/common/utils.c | 202 ++- c/driver/common/utils.h | 18 +- c/driver/common/utils_test.cc | 95 ++ c/driver/flightsql/dremio_flightsql_test.cc | 3 +- c/driver/flightsql/sqlite_flightsql_test.cc | 177 ++- c/driver/postgresql/CMakeLists.txt | 1 + c/driver/postgresql/connection.cc | 511 ++++++- c/driver/postgresql/connection.h | 23 +- c/driver/postgresql/database.cc | 45 +- c/driver/postgresql/database.h | 19 + c/driver/postgresql/error.cc | 97 ++ c/driver/postgresql/error.h | 42 + c/driver/postgresql/postgres_copy_reader.h | 8 +- c/driver/postgresql/postgresql.cc | 472 ++++++- c/driver/postgresql/postgresql_test.cc | 385 +++++- c/driver/postgresql/statement.cc | 392 ++++-- c/driver/postgresql/statement.h | 41 +- c/driver/snowflake/snowflake_test.cc | 8 +- c/driver/sqlite/sqlite.c | 312 ++++- c/driver/sqlite/sqlite_test.cc | 20 + c/driver_manager/CMakeLists.txt | 19 +- c/driver_manager/adbc_driver_manager.cc | 1200 ++++++++++++++--- c/driver_manager/adbc_driver_manager_test.cc | 61 +- c/driver_manager/adbc_version_100.c | 117 ++ c/driver_manager/adbc_version_100.h | 94 ++ .../adbc_version_100_compatibility_test.cc | 111 ++ c/integration/duckdb/CMakeLists.txt | 1 + c/integration/duckdb/duckdb_test.cc | 8 +- c/symbols.map | 10 + c/validation/CMakeLists.txt | 19 +- c/validation/adbc_validation.cc | 611 +++++++-- c/validation/adbc_validation.h | 102 +- c/validation/adbc_validation_util.h | 1 + ci/docker/python-wheel-manylinux.dockerfile | 4 +- ci/scripts/python_wheel_unix_build.sh | 2 +- dev/release/verify-release-candidate.sh | 8 +- docker-compose.yml | 26 +- docs/source/format/specification.rst | 111 ++ docs/source/format/versioning.rst | 28 +- .../source/python/api/adbc_driver_manager.rst | 2 + .../recipe/postgresql_create_append_table.py | 2 +- go/adbc/adbc.go | 265 +++- go/adbc/driver/flightsql/flightsql_adbc.go | 487 +++++-- .../flightsql/flightsql_adbc_server_test.go | 268 ++++ .../driver/flightsql/flightsql_adbc_test.go | 17 +- .../driver/flightsql/flightsql_statement.go | 219 ++- go/adbc/driver/flightsql/record_reader.go | 20 +- go/adbc/driver/flightsql/utils.go | 63 +- go/adbc/driver/snowflake/connection.go | 115 +- go/adbc/driver/snowflake/driver.go | 128 ++ go/adbc/driver/snowflake/driver_test.go | 25 +- go/adbc/driver/snowflake/statement.go | 116 +- go/adbc/drivermgr/adbc.h | 1108 ++++++++++++++- go/adbc/drivermgr/adbc_driver_manager.cc | 1200 ++++++++++++++--- go/adbc/infocode_string.go | 7 +- go/adbc/pkg/_tmpl/driver.go.tmpl | 1077 +++++++++++++-- go/adbc/pkg/_tmpl/utils.c.tmpl | 353 ++++- go/adbc/pkg/_tmpl/utils.h.tmpl | 76 +- go/adbc/pkg/flightsql/driver.go | 1077 +++++++++++++-- go/adbc/pkg/flightsql/utils.c | 356 ++++- go/adbc/pkg/flightsql/utils.h | 143 +- go/adbc/pkg/panicdummy/driver.go | 1077 +++++++++++++-- go/adbc/pkg/panicdummy/utils.c | 356 ++++- go/adbc/pkg/panicdummy/utils.h | 146 +- go/adbc/pkg/snowflake/driver.go | 1077 +++++++++++++-- go/adbc/pkg/snowflake/utils.c | 356 ++++- go/adbc/pkg/snowflake/utils.h | 143 +- go/adbc/standard_schemas.go | 28 + go/adbc/validation/validation.go | 341 ++++- .../arrow/adbc/core/AdbcConnection.java | 152 ++- .../apache/arrow/adbc/core/AdbcDatabase.java | 2 +- .../apache/arrow/adbc/core/AdbcDriver.java | 35 +- .../apache/arrow/adbc/core/AdbcException.java | 38 +- .../apache/arrow/adbc/core/AdbcInfoCode.java | 17 +- .../apache/arrow/adbc/core/AdbcOptions.java | 45 + .../apache/arrow/adbc/core/AdbcStatement.java | 62 +- .../arrow/adbc/core/BulkIngestMode.java | 15 +- .../apache/arrow/adbc/core/ErrorDetail.java | 60 + .../arrow/adbc/core/StandardSchemas.java | 79 +- .../arrow/adbc/core/StandardStatistics.java | 81 ++ .../org/apache/arrow/adbc/core/TypedKey.java | 87 ++ .../driver/flightsql/FlightSqlQuirks.java | 2 +- .../flightsql/FlightSqlStatementTest.java | 12 + java/driver/flight-sql/pom.xml | 12 + .../driver/flightsql/FlightSqlDriver.java | 17 +- .../driver/flightsql/FlightSqlDriverUtil.java | 22 +- .../driver/flightsql/FlightSqlStatement.java | 12 + .../adbc/driver/flightsql/DetailsTest.java | 381 ++++++ .../driver/jdbc/derby/DerbyStatementTest.java | 5 + .../jdbc/postgresql/PostgresqlQuirks.java | 27 +- .../jdbc/postgresql/StatisticsTest.java | 121 ++ .../adbc/driver/jdbc/InfoMetadataBuilder.java | 22 +- .../adbc/driver/jdbc/JdbcArrowReader.java | 15 +- .../adbc/driver/jdbc/JdbcConnection.java | 206 +++ .../arrow/adbc/driver/jdbc/JdbcStatement.java | 36 + .../testsuite/AbstractConnectionTest.java | 14 + .../testsuite/AbstractStatementTest.java | 57 + .../driver/testsuite/SqlValidationQuirks.java | 9 + python/adbc_driver_manager/MANIFEST.in | 3 + .../adbc_driver_manager/__init__.py | 10 + .../adbc_driver_manager/_lib.pxd | 287 ++++ .../adbc_driver_manager/_lib.pyi | 48 +- .../adbc_driver_manager/_lib.pyx | 690 ++++++---- .../adbc_driver_manager/_reader.pyi | 34 + .../adbc_driver_manager/_reader.pyx | 113 ++ .../adbc_driver_manager/dbapi.py | 92 +- python/adbc_driver_manager/setup.py | 16 +- .../adbc_driver_manager/tests/test_reader.py | 80 ++ .../tests/test_dbapi.py | 144 ++ .../tests/test_lowlevel.py | 12 +- r/adbcdrivermanager/src/driver_log.c | 3 +- r/adbcdrivermanager/src/driver_monkey.c | 2 +- r/adbcdrivermanager/src/driver_void.c | 2 +- r/adbcflightsql/tools/download-go.R | 2 +- r/adbcpostgresql/bootstrap.R | 2 + r/adbcpostgresql/src/.gitignore | 2 + r/adbcpostgresql/src/Makevars.in | 1 + r/adbcpostgresql/src/Makevars.ucrt | 1 + r/adbcpostgresql/src/Makevars.win | 1 + r/adbcsnowflake/configure | 2 +- r/adbcsnowflake/tools/download-go.R | 2 +- 126 files changed, 18453 insertions(+), 1911 deletions(-) create mode 100644 c/driver/postgresql/error.cc create mode 100644 c/driver/postgresql/error.h create mode 100644 c/driver_manager/adbc_version_100.c create mode 100644 c/driver_manager/adbc_version_100.h create mode 100644 c/driver_manager/adbc_version_100_compatibility_test.cc create mode 100644 java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java create mode 100644 java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java create mode 100644 java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java create mode 100644 java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java create mode 100644 java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java create mode 100644 java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_lib.pxd create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_reader.pyi create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_reader.pyx create mode 100644 python/adbc_driver_manager/tests/test_reader.py diff --git a/.env b/.env index 2b73edb57e..0ff75396fd 100644 --- a/.env +++ b/.env @@ -33,7 +33,7 @@ MANYLINUX=2014 MAVEN=3.5.4 PYTHON=3.10 GO=1.19.5 -ARROW_MAJOR_VERSION=12 +ARROW_MAJOR_VERSION=14 # Used through docker-compose.yml and serves as the default version for the # ci/scripts/install_vcpkg.sh script. diff --git a/.gitattributes b/.gitattributes index 7f39f6a39c..f1efc73113 100644 --- a/.gitattributes +++ b/.gitattributes @@ -16,6 +16,11 @@ # under the License. c/vendor/* linguist-vendored +go/adbc/drivermgr/adbc.h linguist-vendored +go/adbc/drivermgr/adbc_driver_manager.cc linguist-vendored +go/adbc/pkg/flightsql/* linguist-generated +go/adbc/pkg/panicdummy/* linguist-generated +go/adbc/pkg/snowflake/* linguist-generated python/adbc_driver_flightsql/adbc_driver_flightsql/_static_version.py export-subst python/adbc_driver_manager/adbc_driver_manager/_static_version.py export-subst python/adbc_driver_postgresql/adbc_driver_postgresql/_static_version.py export-subst diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index dffdab1933..b12aa46d51 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -318,6 +318,7 @@ jobs: popd - name: Go Test env: + SNOWFLAKE_DATABASE: ADBC_TESTING SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} run: | ./ci/scripts/go_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 82df862c70..188b069b56 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,12 +43,12 @@ repos: - id: cmake-format args: [--in-place] - repo: https://github.com/cpplint/cpplint - rev: 1.6.0 + rev: 1.6.1 hooks: - id: cpplint args: # From Arrow's config - - "--filter=-whitespace/comments,-readability/casting,-readability/todo,-readability/alt_tokens,-build/header_guard,-build/c++11,-build/include_order,-build/include_subdir" + - "--filter=-whitespace/comments,-whitespace/indent,-readability/braces,-readability/casting,-readability/todo,-readability/alt_tokens,-build/header_guard,-build/c++11,-build/include_order,-build/include_subdir" - "--linelength=90" - "--verbose=2" - repo: https://github.com/golangci/golangci-lint diff --git a/adbc.h b/adbc.h index 154e881255..1ec2f05080 100644 --- a/adbc.h +++ b/adbc.h @@ -35,7 +35,7 @@ /// but not concurrent access. Specific implementations may permit /// multiple threads. /// -/// \version 1.0.0 +/// \version 1.1.0 #pragma once @@ -248,7 +248,24 @@ typedef uint8_t AdbcStatusCode; /// May indicate a database-side error only. #define ADBC_STATUS_UNAUTHORIZED 14 +/// \brief Inform the driver/driver manager that we are using the extended +/// AdbcError struct from ADBC 1.1.0. +/// +/// See the AdbcError documentation for usage. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA INT32_MIN + /// \brief A detailed error message for an operation. +/// +/// The caller must zero-initialize this struct (clarified in ADBC 1.1.0). +/// +/// The structure was extended in ADBC 1.1.0. Drivers and clients using ADBC +/// 1.0.0 will not have the private_data or private_driver fields. Drivers +/// should read/write these fields if and only if vendor_code is equal to +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. Clients are required to initialize +/// this struct to avoid the possibility of uninitialized values confusing the +/// driver. struct ADBC_EXPORT AdbcError { /// \brief The error message. char* message; @@ -266,8 +283,112 @@ struct ADBC_EXPORT AdbcError { /// Unlike other structures, this is an embedded callback to make it /// easier for the driver manager and driver to cooperate. void (*release)(struct AdbcError* error); + + /// \brief Opaque implementation-defined state. + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. If present, this field is NULLPTR + /// iff the error is unintialized/freed. + /// + /// \since ADBC API revision 1.1.0 + void* private_data; + + /// \brief The associated driver (used by the driver manager to help + /// track state). + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. + /// + /// \since ADBC API revision 1.1.0 + struct AdbcDriver* private_driver; }; +#ifdef __cplusplus +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + (AdbcError{nullptr, \ + ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, \ + {0, 0, 0, 0, 0}, \ + nullptr, \ + nullptr, \ + nullptr}) +#else +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + ((struct AdbcError){ \ + NULL, ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, {0, 0, 0, 0, 0}, NULL, NULL, NULL}) +#endif + +/// \brief The size of the AdbcError structure in ADBC 1.0.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcDriver struct when vendor_code is not +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_0_0_SIZE (offsetof(struct AdbcError, private_data)) +/// \brief The size of the AdbcError structure in ADBC 1.1.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcDriver struct when vendor_code is +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_1_0_SIZE (sizeof(struct AdbcError)) + +/// \brief Extra key-value metadata for an error. +/// +/// The fields here are owned by the driver and should not be freed. The +/// fields here are invalidated when the release callback in AdbcError is +/// called. +/// +/// \since ADBC API revision 1.1.0 +struct ADBC_EXPORT AdbcErrorDetail { + /// \brief The metadata key. + const char* key; + /// \brief The binary metadata value. + const uint8_t* value; + /// \brief The length of the metadata value. + size_t value_length; +}; + +/// \brief Get the number of metadata values available in an error. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +int AdbcErrorGetDetailCount(const struct AdbcError* error); + +/// \brief Get a metadata value in an error by index. +/// +/// If index is invalid, returns an AdbcErrorDetail initialized with NULL/0 +/// fields. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index); + +/// \brief Get an ADBC error from an ArrowArrayStream created by a driver. +/// +/// This allows retrieving error details and other metadata that would +/// normally be suppressed by the Arrow C Stream Interface. +/// +/// The caller MUST NOT release the error; it is managed by the release +/// callback in the stream itself. +/// +/// \param[in] stream The stream to query. +/// \param[out] status The ADBC status code, or ADBC_STATUS_OK if there is no +/// error. Not written to if the stream does not contain an ADBC error or +/// if the pointer is NULL. +/// \return NULL if not supported. +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + /// @} /// \defgroup adbc-constants Constants @@ -279,6 +400,14 @@ struct ADBC_EXPORT AdbcError { /// point to an AdbcDriver. #define ADBC_VERSION_1_0_0 1000000 +/// \brief ADBC revision 1.1.0. +/// +/// When passed to an AdbcDriverInitFunc(), the driver parameter must +/// point to an AdbcDriver. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_VERSION_1_1_0 1001000 + /// \brief Canonical option value for enabling an option. /// /// For use as the value in SetOption calls. @@ -288,6 +417,34 @@ struct ADBC_EXPORT AdbcError { /// For use as the value in SetOption calls. #define ADBC_OPTION_VALUE_DISABLED "false" +/// \brief Canonical option name for URIs. +/// +/// Should be used as the expected option name to specify a URI for +/// any ADBC driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_URI "uri" +/// \brief Canonical option name for usernames. +/// +/// Should be used as the expected option name to specify a username +/// to a driver for authentication. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_USERNAME "username" +/// \brief Canonical option name for passwords. +/// +/// Should be used as the expected option name to specify a password +/// for authentication to a driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_PASSWORD "password" + /// \brief The database vendor/product name (e.g. the server name). /// (type: utf8). /// @@ -315,6 +472,15 @@ struct ADBC_EXPORT AdbcError { /// /// \see AdbcConnectionGetInfo #define ADBC_INFO_DRIVER_ARROW_VERSION 102 +/// \brief The driver ADBC API version (type: int64). +/// +/// The value should be one of the ADBC_VERSION constants. +/// +/// \since ADBC API revision 1.1.0 +/// \see AdbcConnectionGetInfo +/// \see ADBC_VERSION_1_0_0 +/// \see ADBC_VERSION_1_1_0 +#define ADBC_INFO_DRIVER_ADBC_VERSION 103 /// \brief Return metadata on catalogs, schemas, tables, and columns. /// @@ -337,18 +503,133 @@ struct ADBC_EXPORT AdbcError { /// \see AdbcConnectionGetObjects #define ADBC_OBJECT_DEPTH_COLUMNS ADBC_OBJECT_DEPTH_ALL +/// \defgroup adbc-table-statistics ADBC Statistic Types +/// Standard statistic names for AdbcConnectionGetStatistics. +/// @{ + +/// \brief The dictionary-encoded name of the average byte width statistic. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY 0 +/// \brief The average byte width statistic. The average size in bytes of a +/// row in the column. Value type is float64. +/// +/// For example, this is roughly the average length of a string for a string +/// column. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the distinct value count statistic. +#define ADBC_STATISTIC_DISTINCT_COUNT_KEY 1 +/// \brief The distinct value count (NDV) statistic. The number of distinct +/// values in the column. Value type is int64 (when not approximate) or +/// float64 (when approximate). +#define ADBC_STATISTIC_DISTINCT_COUNT_NAME "adbc.statistic.distinct_count" +/// \brief The dictionary-encoded name of the max byte width statistic. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_KEY 2 +/// \brief The max byte width statistic. The maximum size in bytes of a row +/// in the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +/// +/// For example, this is the maximum length of a string for a string column. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the max value statistic. +#define ADBC_STATISTIC_MAX_VALUE_KEY 3 +/// \brief The max value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MAX_VALUE_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the min value statistic. +#define ADBC_STATISTIC_MIN_VALUE_KEY 4 +/// \brief The min value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MIN_VALUE_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the null count statistic. +#define ADBC_STATISTIC_NULL_COUNT_KEY 5 +/// \brief The null count statistic. The number of values that are null in +/// the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +#define ADBC_STATISTIC_NULL_COUNT_NAME "adbc.statistic.null_count" +/// \brief The dictionary-encoded name of the row count statistic. +#define ADBC_STATISTIC_ROW_COUNT_KEY 6 +/// \brief The row count statistic. The number of rows in the column or +/// table. Value type is int64 (when not approximate) or float64 (when +/// approximate). +#define ADBC_STATISTIC_ROW_COUNT_NAME "adbc.statistic.row_count" +/// @} + /// \brief The name of the canonical option for whether autocommit is /// enabled. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_AUTOCOMMIT "adbc.connection.autocommit" /// \brief The name of the canonical option for whether the current /// connection should be restricted to being read-only. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_READ_ONLY "adbc.connection.readonly" +/// \brief The name of the canonical option for the current catalog. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_CATALOG "adbc.connection.catalog" + +/// \brief The name of the canonical option for the current schema. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA "adbc.connection.db_schema" + +/// \brief The name of the canonical option for making query execution +/// nonblocking. +/// +/// When enabled, AdbcStatementExecutePartitions will return +/// partitions as soon as they are available, instead of returning +/// them all at the end. When there are no more to return, it will +/// return an empty set of partitions. AdbcStatementExecuteQuery and +/// AdbcStatementExecuteSchema are not affected. +/// +/// The default is ADBC_OPTION_VALUE_DISABLED. +/// +/// The type is char*. +/// +/// \see AdbcStatementSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_INCREMENTAL "adbc.statement.exec.incremental" + +/// \brief The name of the option for getting the progress of a query. +/// +/// The value is not necessarily in any particular range or have any +/// particular units. (For example, it might be a percentage, bytes of data, +/// rows of data, number of workers, etc.) The max value can be retrieved via +/// ADBC_STATEMENT_OPTION_MAX_PROGRESS. This represents the progress of +/// execution, not of consumption (i.e., it is independent of how much of the +/// result set has been read by the client via ArrowArrayStream.get_next().) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_PROGRESS "adbc.statement.exec.progress" + +/// \brief The name of the option for getting the maximum progress of a query. +/// +/// This is the value of ADBC_STATEMENT_OPTION_PROGRESS for a completed query. +/// If not supported, or if the value is nonpositive, then the maximum is not +/// known. (For instance, the query may be fully streaming and the driver +/// does not know when the result set will end.) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_MAX_PROGRESS "adbc.statement.exec.max_progress" + /// \brief The name of the canonical option for setting the isolation /// level of a transaction. /// @@ -357,6 +638,8 @@ struct ADBC_EXPORT AdbcError { /// isolation level is not supported by a driver, it should return an /// appropriate error. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_ISOLATION_LEVEL \ "adbc.connection.transaction.isolation_level" @@ -449,8 +732,12 @@ struct ADBC_EXPORT AdbcError { /// exist. If the table exists but has a different schema, /// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be /// appended to the target table. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table" /// \brief Whether to create (the default) or append. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_MODE "adbc.ingest.mode" /// \brief Create the table and insert data; error if the table exists. #define ADBC_INGEST_OPTION_MODE_CREATE "adbc.ingest.mode.create" @@ -458,6 +745,15 @@ struct ADBC_EXPORT AdbcError { /// table does not exist (ADBC_STATUS_NOT_FOUND) or does not match /// the schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). #define ADBC_INGEST_OPTION_MODE_APPEND "adbc.ingest.mode.append" +/// \brief Create the table and insert data; drop the original table +/// if it already exists. +/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_REPLACE "adbc.ingest.mode.replace" +/// \brief Insert data; create the table if it does not exist, or +/// error if the table exists, but the schema does not match the +/// schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). +/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_CREATE_APPEND "adbc.ingest.mode.create_append" /// @} @@ -624,7 +920,7 @@ struct ADBC_EXPORT AdbcDriver { AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); - AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, const uint32_t*, size_t, struct ArrowArrayStream*, struct AdbcError*); AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, const char*, const char*, const char**, @@ -667,8 +963,108 @@ struct ADBC_EXPORT AdbcDriver { struct AdbcError*); AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, size_t, struct AdbcError*); + + /// \defgroup adbc-1.1.0 ADBC API Revision 1.1.0 + /// + /// Functions added in ADBC 1.1.0. For backwards compatibility, + /// these members must not be accessed unless the version passed to + /// the AdbcDriverInitFunc is greater than or equal to + /// ADBC_VERSION_1_1_0. + /// + /// For a 1.0.0 driver being loaded by a 1.1.0 driver manager: the + /// 1.1.0 manager will allocate the new, expanded AdbcDriver struct + /// and attempt to have the driver initialize it with + /// ADBC_VERSION_1_1_0. This must return an error, after which the + /// driver will try again with ADBC_VERSION_1_0_0. The driver must + /// not access the new fields, which will carry undefined values. + /// + /// For a 1.1.0 driver being loaded by a 1.0.0 driver manager: the + /// 1.0.0 manager will allocate the old AdbcDriver struct and + /// attempt to have the driver initialize it with + /// ADBC_VERSION_1_0_0. The driver must not access the new fields, + /// and should initialize the old fields. + /// + /// @{ + + int (*ErrorGetDetailCount)(const struct AdbcError* error); + struct AdbcErrorDetail (*ErrorGetDetail)(const struct AdbcError* error, int index); + const struct AdbcError* (*ErrorFromArrayStream)(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + + AdbcStatusCode (*DatabaseGetOption)(struct AdbcDatabase*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionBytes)(struct AdbcDatabase*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionDouble)(struct AdbcDatabase*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionInt)(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionBytes)(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionDouble)(struct AdbcDatabase*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionInt)(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*ConnectionCancel)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOption)(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionBytes)(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionDouble)(struct AdbcConnection*, const char*, + double*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionInt)(struct AdbcConnection*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatistics)(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatisticNames)(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionBytes)(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionDouble)(struct AdbcConnection*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionInt)(struct AdbcConnection*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*StatementCancel)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementExecuteSchema)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOption)(struct AdbcStatement*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionBytes)(struct AdbcStatement*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*StatementGetOptionDouble)(struct AdbcStatement*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionInt)(struct AdbcStatement*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionBytes)(struct AdbcStatement*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*StatementSetOptionDouble)(struct AdbcStatement*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionInt)(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); + + /// @} }; +/// \brief The size of the AdbcDriver structure in ADBC 1.0.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_0_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_0_0_SIZE (offsetof(struct AdbcDriver, ErrorGetDetailCount)) + +/// \brief The size of the AdbcDriver structure in ADBC 1.1.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_1_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_1_0_SIZE (sizeof(struct AdbcDriver)) + /// @} /// \addtogroup adbc-database @@ -684,16 +1080,189 @@ struct ADBC_EXPORT AdbcDriver { ADBC_EXPORT AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error); +/// \brief Get a string option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a double option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the double +/// representation of an integer option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error); + +/// \brief Get an integer option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the integer +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error); + /// \brief Set a char* option. /// /// Options may be set before AdbcDatabaseInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error); + +/// \brief Set a double option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error); + +/// \brief Set an integer option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error); + /// \brief Finish setting options and initialize the database. /// /// Some drivers may support setting options after initialization @@ -730,11 +1299,65 @@ AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, /// Options may be set before AdbcConnectionInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a connection. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error); + +/// \brief Set a double option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error); + /// \brief Finish setting options and initialize the connection. /// /// Some drivers may support setting options after initialization @@ -752,6 +1375,30 @@ ADBC_EXPORT AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, struct AdbcError* error); +/// \brief Cancel the in-progress operation on a connection. +/// +/// This can be called during AdbcConnectionGetObjects (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] connection The connection to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no operation to cancel. +/// \return ADBC_STATUS_UNKNOWN if the operation could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error); + /// \defgroup adbc-connection-metadata Metadata /// Functions for retrieving metadata about the database. /// @@ -765,6 +1412,8 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// concurrent active statements and it must execute a SQL query /// internally in order to implement the metadata function). /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// Some functions accept "search pattern" arguments, which are /// strings that can contain the special character "%" to match zero /// or more characters, or "_" to match exactly one character. (See @@ -799,6 +1448,10 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// for ADBC usage. Drivers/vendors will ignore requests for /// unrecognized codes (the row will be omitted from the result). /// +/// Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" +/// information, which is the same metadata provided by the same info +/// code range in the Arrow Flight SQL GetSqlInfo RPC. +/// /// \param[in] connection The connection to query. /// \param[in] info_codes A list of metadata codes to fetch, or NULL /// to fetch all. @@ -808,7 +1461,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// \param[out] error Error details, if an error occurs. ADBC_EXPORT AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error); @@ -891,6 +1544,8 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, /// | fk_table | utf8 not null | /// | fk_column_name | utf8 not null | /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[in] depth The level of nesting to display. If 0, display /// all levels. If 1, display only catalogs (i.e. catalog_schemas @@ -922,6 +1577,212 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct ArrowArrayStream* out, struct AdbcError* error); +/// \brief Get a string option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error); + +/// \brief Get a double option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error); + +/// \brief Get statistics about the data distribution of table(s). +/// +/// The result is an Arrow dataset with the following schema: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | catalog_name | utf8 | +/// | catalog_db_schemas | list not null | +/// +/// DB_SCHEMA_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | db_schema_name | utf8 | +/// | db_schema_statistics | list not null | +/// +/// STATISTICS_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | Comments | +/// |--------------------------|----------------------------------| -------- | +/// | table_name | utf8 not null | | +/// | column_name | utf8 | (1) | +/// | statistic_key | int16 not null | (2) | +/// | statistic_value | VALUE_SCHEMA not null | | +/// | statistic_is_approximate | bool not null | (3) | +/// +/// 1. If null, then the statistic applies to the entire table. +/// 2. A dictionary-encoded statistic name (although we do not use the Arrow +/// dictionary type). Values in [0, 1024) are reserved for ADBC. Other +/// values are for implementation-specific statistics. For the definitions +/// of predefined statistic types, see \ref adbc-table-statistics. To get +/// driver-specific statistic names, use AdbcConnectionGetStatisticNames. +/// 3. If true, then the value is approximate or best-effort. +/// +/// VALUE_SCHEMA is a dense union with members: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | int64 | int64 | +/// | uint64 | uint64 | +/// | float64 | float64 | +/// | binary | binary | +/// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] catalog The catalog (or nullptr). May be a search +/// pattern (see section documentation). +/// \param[in] db_schema The database schema (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] table_name The table name (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] approximate If zero, request exact values of +/// statistics, else allow for best-effort, approximate, or cached +/// values. The database may return approximate values regardless, +/// as indicated in the result. Requesting exact values may be +/// expensive or unsupported. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// \brief Get the names of statistics specific to this driver. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ---------------|---------------- +/// statistic_name | utf8 not null +/// statistic_key | int16 not null +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error); + /// \brief Get the Arrow schema of a table. /// /// \param[in] connection The database connection. @@ -945,6 +1806,8 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, /// ---------------|-------------- /// table_type | utf8 not null /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[out] out The result set. /// \param[out] error Error details, if an error occurs. @@ -973,6 +1836,8 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, /// /// A partition can be retrieved from AdbcPartitions. /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The connection to use. This does not have /// to be the same connection that the partition was created on. /// \param[in] serialized_partition The partition descriptor. @@ -1042,7 +1907,11 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, /// \brief Execute a statement and get the results. /// -/// This invalidates any prior result sets. +/// This invalidates any prior result sets. This AdbcStatement must +/// outlive the returned ArrowArrayStream. +/// +/// Since ADBC 1.1.0: releasing the returned ArrowArrayStream without +/// consuming it fully is equivalent to calling AdbcStatementCancel. /// /// \param[in] statement The statement to execute. /// \param[out] out The results. Pass NULL if the client does not @@ -1056,6 +1925,27 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, struct ArrowArrayStream* out, int64_t* rows_affected, struct AdbcError* error); +/// \brief Get the schema of the result set of a query without +/// executing it. +/// +/// This invalidates any prior result sets. +/// +/// Depending on the driver, this may require first executing +/// AdbcStatementPrepare. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to execute. +/// \param[out] out The result schema. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support this. +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error); + /// \brief Turn this statement into a prepared statement to be /// executed multiple times. /// @@ -1138,6 +2028,158 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, struct ArrowArrayStream* stream, struct AdbcError* error); +/// \brief Cancel execution of an in-progress query. +/// +/// This can be called during AdbcStatementExecuteQuery (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no query to cancel. +/// \return ADBC_STATUS_UNKNOWN if the query could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error); + +/// \brief Get a string option of the statement. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the statement. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error); + +/// \brief Get a double option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error); + /// \brief Get the schema for bound parameters. /// /// This retrieves an Arrow schema describing the number, names, and @@ -1159,10 +2201,58 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct AdbcError* error); /// \brief Set a string option on a statement. +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized. ADBC_EXPORT AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error); + +/// \brief Set a double option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error); + /// \addtogroup adbc-statement-partition /// @{ @@ -1198,7 +2288,15 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, /// driver. /// /// Although drivers may choose any name for this function, the -/// recommended name is "AdbcDriverInit". +/// recommended name is "AdbcDriverInit", or a name derived from the +/// name of the driver's shared library as follows: remove the 'lib' +/// prefix (on Unix systems) and all file extensions, then PascalCase +/// the driver name, append Init, and prepend Adbc (if not already +/// there). For example: +/// +/// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit +/// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit +/// - proprietary_driver.dll -> AdbcProprietaryDriverInit /// /// \param[in] version The ADBC revision to attempt to initialize (see /// ADBC_VERSION_1_0_0). diff --git a/c/driver/common/utils.c b/c/driver/common/utils.c index dfac14f5e4..71d9e7ef01 100644 --- a/c/driver/common/utils.c +++ b/c/driver/common/utils.c @@ -17,15 +17,80 @@ #include "utils.h" +#include #include -#include #include #include #include -#include +#include + +static size_t kErrorBufferSize = 1024; + +int AdbcStatusCodeToErrno(AdbcStatusCode code) { + switch (code) { + case ADBC_STATUS_OK: + return 0; + case ADBC_STATUS_UNKNOWN: + return EIO; + case ADBC_STATUS_NOT_IMPLEMENTED: + return ENOTSUP; + case ADBC_STATUS_NOT_FOUND: + return ENOENT; + case ADBC_STATUS_ALREADY_EXISTS: + return EEXIST; + case ADBC_STATUS_INVALID_ARGUMENT: + case ADBC_STATUS_INVALID_STATE: + return EINVAL; + case ADBC_STATUS_INVALID_DATA: + case ADBC_STATUS_INTEGRITY: + case ADBC_STATUS_INTERNAL: + case ADBC_STATUS_IO: + return EIO; + case ADBC_STATUS_CANCELLED: + return ECANCELED; + case ADBC_STATUS_TIMEOUT: + return ETIMEDOUT; + case ADBC_STATUS_UNAUTHENTICATED: + // FreeBSD/macOS have EAUTH, but not other platforms + case ADBC_STATUS_UNAUTHORIZED: + return EACCES; + default: + return EIO; + } +} + +/// For ADBC 1.1.0, the structure held in private_data. +struct AdbcErrorDetails { + char* message; + + // The metadata keys (may be NULL). + char** keys; + // The metadata values (may be NULL). + uint8_t** values; + // The metadata value lengths (may be NULL). + size_t* lengths; + // The number of initialized metadata. + int count; + // The length of the keys/values/lengths arrays above. + int capacity; +}; + +static void ReleaseErrorWithDetails(struct AdbcError* error) { + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + free(details->message); -static size_t kErrorBufferSize = 256; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + + free(details->keys); + free(details->values); + free(details->lengths); + free(error->private_data); + *error = ADBC_ERROR_INIT; +} static void ReleaseError(struct AdbcError* error) { free(error->message); @@ -34,20 +99,126 @@ static void ReleaseError(struct AdbcError* error) { } void SetError(struct AdbcError* error, const char* format, ...) { + va_list args; + va_start(args, format); + SetErrorVariadic(error, format, args); + va_end(args); +} + +void SetErrorVariadic(struct AdbcError* error, const char* format, va_list args) { if (!error) return; if (error->release) { // TODO: combine the errors if possible error->release(error); } - error->message = malloc(kErrorBufferSize); - if (!error->message) return; - error->release = &ReleaseError; + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + error->private_data = malloc(sizeof(struct AdbcErrorDetails)); + if (!error->private_data) return; + + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + + details->message = malloc(kErrorBufferSize); + if (!details->message) { + free(details); + return; + } + details->keys = NULL; + details->values = NULL; + details->lengths = NULL; + details->count = 0; + details->capacity = 0; + + error->message = details->message; + error->release = &ReleaseErrorWithDetails; + } else { + error->message = malloc(kErrorBufferSize); + if (!error->message) return; + + error->release = &ReleaseError; + } - va_list args; - va_start(args, format); vsnprintf(error->message, kErrorBufferSize, format, args); - va_end(args); +} + +void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* detail, + size_t detail_length) { + if (error->release != ReleaseErrorWithDetails) return; + + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + if (details->count >= details->capacity) { + int new_capacity = (details->capacity == 0) ? 4 : (2 * details->capacity); + char** new_keys = calloc(new_capacity, sizeof(char*)); + if (!new_keys) { + return; + } + + uint8_t** new_values = calloc(new_capacity, sizeof(uint8_t*)); + if (!new_values) { + free(new_keys); + return; + } + + size_t* new_lengths = calloc(new_capacity, sizeof(size_t*)); + if (!new_lengths) { + free(new_keys); + free(new_values); + return; + } + + memcpy(new_keys, details->keys, sizeof(char*) * details->count); + free(details->keys); + details->keys = new_keys; + + memcpy(new_values, details->values, sizeof(uint8_t*) * details->count); + free(details->values); + details->values = new_values; + + memcpy(new_lengths, details->lengths, sizeof(size_t) * details->count); + free(details->lengths); + details->lengths = new_lengths; + + details->capacity = new_capacity; + } + + char* key_data = strdup(key); + if (!key_data) return; + uint8_t* value_data = malloc(detail_length); + if (!value_data) { + free(key_data); + return; + } + memcpy(value_data, detail, detail_length); + + int index = details->count; + details->keys[index] = key_data; + details->values[index] = value_data; + details->lengths[index] = detail_length; + + details->count++; +} + +int CommonErrorGetDetailCount(const struct AdbcError* error) { + if (error->release != ReleaseErrorWithDetails) { + return 0; + } + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + return details->count; +} + +struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int index) { + if (error->release != ReleaseErrorWithDetails) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index], + }; } struct SingleBatchArrayStream { @@ -244,6 +415,19 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, return ADBC_STATUS_OK; } +AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, + uint32_t info_code, int64_t info_value, + struct AdbcError* error) { + CHECK_NA(INTERNAL, ArrowArrayAppendUInt(array->children[0], info_code), error); + // Append to type variant + CHECK_NA(INTERNAL, ArrowArrayAppendInt(array->children[1]->children[2], info_value), + error); + // Append type code/offset + CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2), + error); + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, struct AdbcError* error) { ArrowSchemaInit(schema); diff --git a/c/driver/common/utils.h b/c/driver/common/utils.h index 5735bb945f..e3d81cb0f6 100644 --- a/c/driver/common/utils.h +++ b/c/driver/common/utils.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -27,6 +28,8 @@ extern "C" { #endif +int AdbcStatusCodeToErrno(AdbcStatusCode code); + // The printf checking attribute doesn't work properly on gcc 4.8 // and results in spurious compiler warnings #if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) @@ -35,10 +38,20 @@ extern "C" { #define ADBC_CHECK_PRINTF_ATTRIBUTE #endif -/// Set error details using a format string. +/// Set error message using a format string. void SetError(struct AdbcError* error, const char* format, ...) ADBC_CHECK_PRINTF_ATTRIBUTE; +/// Set error message using a format string. +void SetErrorVariadic(struct AdbcError* error, const char* format, va_list args); + +/// Add an error detail. +void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* detail, + size_t detail_length); + +int CommonErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int index); + struct StringBuilder { char* buffer; // Not including null terminator @@ -117,6 +130,9 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, const char* info_value, struct AdbcError* error); +AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, + uint32_t info_code, int64_t info_value, + struct AdbcError* error); AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, struct AdbcError* error); diff --git a/c/driver/common/utils_test.cc b/c/driver/common/utils_test.cc index 6fa7e254df..d5c202bf2e 100644 --- a/c/driver/common/utils_test.cc +++ b/c/driver/common/utils_test.cc @@ -15,6 +15,12 @@ // specific language governing permissions and limitations // under the License. +#include +#include +#include +#include + +#include #include #include "utils.h" @@ -72,3 +78,92 @@ TEST(TestStringBuilder, TestMultipleAppends) { StringBuilderReset(&str); } + +TEST(ErrorDetails, Adbc100) { + struct AdbcError error; + std::memset(&error, 0, ADBC_ERROR_1_1_0_SIZE); + + SetError(&error, "My message"); + + ASSERT_EQ(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); + + { + std::string detail = "detail"; + AppendErrorDetail(&error, "key", reinterpret_cast(detail.data()), + detail.size()); + } + + ASSERT_EQ(0, CommonErrorGetDetailCount(&error)); + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, 0); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + error.release(&error); +} + +TEST(ErrorDetails, Adbc110) { + struct AdbcError error = ADBC_ERROR_INIT; + SetError(&error, "My message"); + + ASSERT_NE(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); + + { + std::string detail = "detail"; + AppendErrorDetail(&error, "key", reinterpret_cast(detail.data()), + detail.size()); + } + + ASSERT_EQ(1, CommonErrorGetDetailCount(&error)); + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, 0); + ASSERT_STREQ("key", detail.key); + ASSERT_EQ("detail", std::string_view(reinterpret_cast(detail.value), + detail.value_length)); + + detail = CommonErrorGetDetail(&error, -1); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + detail = CommonErrorGetDetail(&error, 2); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + error.release(&error); + ASSERT_EQ(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); +} + +TEST(ErrorDetails, RoundTripValues) { + struct AdbcError error = ADBC_ERROR_INIT; + SetError(&error, "My message"); + + struct Detail { + std::string key; + std::vector value; + }; + + std::vector details = { + {"x-key-1", {0, 1, 2, 3}}, {"x-key-2", {1, 1}}, {"x-key-3", {128, 129, 200, 0, 1}}, + {"x-key-4", {97, 98, 99}}, {"x-key-5", {42}}, + }; + + for (const auto& detail : details) { + AppendErrorDetail(&error, detail.key.c_str(), detail.value.data(), + detail.value.size()); + } + + ASSERT_EQ(details.size(), CommonErrorGetDetailCount(&error)); + for (int i = 0; i < static_cast(details.size()); i++) { + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, i); + ASSERT_EQ(details[i].key, detail.key); + ASSERT_EQ(details[i].value.size(), detail.value_length); + ASSERT_THAT(std::vector(detail.value, detail.value + detail.value_length), + ::testing::ElementsAreArray(details[i].value)); + } + + error.release(&error); +} diff --git a/c/driver/flightsql/dremio_flightsql_test.cc b/c/driver/flightsql/dremio_flightsql_test.cc index c128bd49f3..416b8aeaf5 100644 --- a/c/driver/flightsql/dremio_flightsql_test.cc +++ b/c/driver/flightsql/dremio_flightsql_test.cc @@ -42,11 +42,11 @@ class DremioFlightSqlQuirks : public adbc_validation::DriverQuirks { } std::string BindParameter(int index) const override { return "?"; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return false; } bool supports_get_sql_info() const override { return false; } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return false; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return false; } }; @@ -87,6 +87,7 @@ class DremioFlightSqlStatementTest : public ::testing::Test, void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); } void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); } + void TestResultInvalidation() { GTEST_SKIP() << "Dremio generates a CANCELLED"; } void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; } protected: diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc index b61b47bc6f..46ca69be4a 100644 --- a/c/driver/flightsql/sqlite_flightsql_test.cc +++ b/c/driver/flightsql/sqlite_flightsql_test.cc @@ -16,6 +16,7 @@ // under the License. #include +#include #include #include @@ -32,6 +33,10 @@ using adbc_validation::IsOkErrno; using adbc_validation::IsOkStatus; +extern "C" { +AdbcStatusCode FlightSQLDriverInit(int, void*, struct AdbcError*); +} + #define CHECK_OK(EXPR) \ do { \ if (auto adbc_status = (EXPR); adbc_status != ADBC_STATUS_OK) { \ @@ -89,11 +94,31 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks { } std::string BindParameter(int index) const override { return "?"; } + + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return false; } bool supports_get_sql_info() const override { return true; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC Flight SQL Driver - Go"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown or development build)"; + case ADBC_INFO_DRIVER_ADBC_VERSION: + return ADBC_VERSION_1_1_0; + case ADBC_INFO_VENDOR_NAME: + return "db_name"; + case ADBC_INFO_VENDOR_VERSION: + return "sqlite 3"; + case ADBC_INFO_VENDOR_ARROW_VERSION: + return "12.0.0"; + default: + return std::nullopt; + } + } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return false; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return true; } }; @@ -209,6 +234,20 @@ TEST_F(SqliteFlightSqlTest, TestGarbageInput) { ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error)); } +TEST_F(SqliteFlightSqlTest, AdbcDriverBackwardsCompatibility) { + // XXX: sketchy cast + auto* driver = static_cast(malloc(ADBC_DRIVER_1_0_0_SIZE)); + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + + ASSERT_THAT(::FlightSQLDriverInit(ADBC_VERSION_1_0_0, driver, &error), + IsOkStatus(&error)); + + ASSERT_THAT(::FlightSQLDriverInit(424242, driver, &error), + adbc_validation::IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + free(driver); +} + class SqliteFlightSqlConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: @@ -237,3 +276,139 @@ class SqliteFlightSqlStatementTest : public ::testing::Test, SqliteFlightSqlQuirks quirks_; }; ADBCV_TEST_STATEMENT(SqliteFlightSqlStatementTest) + +// Test what happens when using the ADBC 1.1.0 error structure +TEST_F(SqliteFlightSqlStatementTest, NonexistentTable) { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + "SELECT * FROM tabledoesnotexist", &error), + IsOkStatus(&error)); + + for (auto vendor_code : {0, ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA}) { + error.vendor_code = vendor_code; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error)); + ASSERT_EQ(0, AdbcErrorGetDetailCount(&error)); + error.release(&error); + } +} + +TEST_F(SqliteFlightSqlStatementTest, CancelError) { + // Ensure cancellation propagates properly through the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT 1", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementCancel(&statement.value, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_EQ(ECANCELED, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* adbc_error = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, adbc_error); + ASSERT_EQ(ADBC_STATUS_CANCELLED, status); +} + +TEST_F(SqliteFlightSqlStatementTest, RpcError) { + // Ensure errors that happen at the start of the stream propagate properly + // through the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error)); + + int count = AdbcErrorGetDetailCount(&error); + ASSERT_NE(0, count); + for (int i = 0; i < count; i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(&error, i); + ASSERT_NE(nullptr, detail.key); + ASSERT_NE(nullptr, detail.value); + ASSERT_NE(0, detail.value_length); + EXPECT_STREQ("afsql-sqlite-query", detail.key); + EXPECT_EQ("SELECT", std::string_view(reinterpret_cast(detail.value), + detail.value_length)); + } +} + +TEST_F(SqliteFlightSqlStatementTest, StreamError) { + // Ensure errors that happen during the stream propagate properly through + // the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + R"( +DROP TABLE IF EXISTS foo; +CREATE TABLE foo (a INT); +WITH RECURSIVE sequence(x) AS + (SELECT 1 UNION ALL SELECT x+1 FROM sequence LIMIT 1024) +INSERT INTO foo(a) +SELECT x FROM sequence; +INSERT INTO foo(a) VALUES ('foo');)", + &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM foo", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_NE(0, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* adbc_error = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, adbc_error); + ASSERT_EQ(ADBC_STATUS_UNKNOWN, status); + + int count = AdbcErrorGetDetailCount(adbc_error); + ASSERT_NE(0, count); + for (int i = 0; i < count; i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(adbc_error, i); + ASSERT_NE(nullptr, detail.key); + ASSERT_NE(nullptr, detail.value); + ASSERT_NE(0, detail.value_length); + EXPECT_STREQ("grpc-status-details-bin", detail.key); + } +} diff --git a/c/driver/postgresql/CMakeLists.txt b/c/driver/postgresql/CMakeLists.txt index d16979419a..9cf595a386 100644 --- a/c/driver/postgresql/CMakeLists.txt +++ b/c/driver/postgresql/CMakeLists.txt @@ -29,6 +29,7 @@ endif() add_arrow_lib(adbc_driver_postgresql SOURCES connection.cc + error.cc database.cc postgresql.cc statement.cc diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index 08ff9027c3..c106ffca78 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -32,29 +33,43 @@ #include "common/utils.h" #include "database.h" +#include "error.h" +namespace adbcpq { namespace { static const uint32_t kSupportedInfoCodes[] = { - ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION, + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, + ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ARROW_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, }; static const std::unordered_map kPgTableTypes = { {"table", "r"}, {"view", "v"}, {"materialized_view", "m"}, {"toast_table", "t"}, {"foreign_table", "f"}, {"partitioned_table", "p"}}; +/// \brief A single column in a single row of a result set. struct PqRecord { const char* data; const int len; const bool is_null; + + // XXX: can't use optional due to R + std::pair ParseDouble() const { + char* end; + double result = std::strtod(data, &end); + if (errno != 0 || end == data) { + return std::make_pair(false, 0.0); + } + return std::make_pair(true, result); + } }; // Used by PqResultHelper to provide index-based access to the records within each -// row of a pg_result +// row of a PGresult class PqResultRow { public: - PqResultRow(pg_result* result, int row_num) : result_(result), row_num_(row_num) { + PqResultRow(PGresult* result, int row_num) : result_(result), row_num_(row_num) { ncols_ = PQnfields(result); } @@ -68,7 +83,7 @@ class PqResultRow { } private: - pg_result* result_ = nullptr; + PGresult* result_ = nullptr; int row_num_; int ncols_; }; @@ -94,10 +109,11 @@ class PqResultHelper { PGresult* result = PQprepare(conn_, /*stmtName=*/"", query_.c_str(), param_values_.size(), NULL); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error_, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(conn_), query_.c_str()); + AdbcStatusCode code = + SetError(error_, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", + PQerrorMessage(conn_), query_.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); @@ -114,9 +130,12 @@ class PqResultHelper { result_ = PQexecPrepared(conn_, "", param_values_.size(), param_c_strs.data(), NULL, NULL, 0); - if (PQresultStatus(result_) != PGRES_TUPLES_OK) { - SetError(error_, "[libpq] Failed to execute query: %s", PQerrorMessage(conn_)); - return ADBC_STATUS_IO; + ExecStatusType status = PQresultStatus(result_); + if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK) { + AdbcStatusCode error = + SetError(error_, result_, "[libpq] Failed to execute query '%s': %s", + query_.c_str(), PQerrorMessage(conn_)); + return error; } return ADBC_STATUS_OK; @@ -164,7 +183,7 @@ class PqResultHelper { iterator end() { return iterator(*this, NumRows()); } private: - pg_result* result_ = nullptr; + PGresult* result_ = nullptr; PGconn* conn_; std::string query_; std::vector param_values_; @@ -727,7 +746,19 @@ class PqGetObjectsHelper { } // namespace -namespace adbcpq { +AdbcStatusCode PostgresConnection::Cancel(struct AdbcError* error) { + // > errbuf must be a char array of size errbufsize (the recommended size is + // > 256 bytes). + // https://www.postgresql.org/docs/current/libpq-cancel.html + char errbuf[256]; + // > The return value is 1 if the cancel request was successfully dispatched + // > and 0 if not. + if (PQcancel(cancel_, errbuf, sizeof(errbuf)) != 1) { + SetError(error, "[libpq] Failed to cancel operation: %s", errbuf); + return ADBC_STATUS_UNKNOWN; + } + return ADBC_STATUS_OK; +} AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { if (autocommit_) { @@ -737,9 +768,10 @@ AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { PGresult* result = PQexec(conn_, "COMMIT"); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "%s%s", "[libpq] Failed to commit: ", PQerrorMessage(conn_)); + AdbcStatusCode code = SetError(error, result, "%s%s", + "[libpq] Failed to commit: ", PQerrorMessage(conn_)); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; @@ -776,6 +808,10 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], NANOARROW_VERSION, error)); break; + case ADBC_INFO_DRIVER_ADBC_VERSION: + RAISE_ADBC(AdbcConnectionGetInfoAppendInt(array, info_codes[i], + ADBC_VERSION_1_1_0, error)); + break; default: // Ignore continue; @@ -791,13 +827,12 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, } AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { - // XXX: mistake in adbc.h (should have been const pointer) - const uint32_t* codes = info_codes; if (!info_codes) { - codes = kSupportedInfoCodes; + info_codes = kSupportedInfoCodes; info_codes_length = sizeof(kSupportedInfoCodes) / sizeof(kSupportedInfoCodes[0]); } @@ -806,8 +841,8 @@ AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, struct ArrowArray array; std::memset(&array, 0, sizeof(array)); - AdbcStatusCode status = - PostgresConnectionGetInfoImpl(codes, info_codes_length, &schema, &array, error); + AdbcStatusCode status = PostgresConnectionGetInfoImpl(info_codes, info_codes_length, + &schema, &array, error); if (status != ADBC_STATUS_OK) { if (schema.release) schema.release(&schema); if (array.release) array.release(&array); @@ -840,6 +875,399 @@ AdbcStatusCode PostgresConnection::GetObjects( return BatchToArrayStream(&array, &schema, out, error); } +AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, + size_t* length, struct AdbcError* error) { + std::string output; + if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) { + output = PQdb(conn_); + } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { + PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA", {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + return ADBC_STATUS_INTERNAL; + } + output = (*it)[0].data; + } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { + output = autocommit_ ? ADBC_OPTION_VALUE_ENABLED : ADBC_OPTION_VALUE_DISABLED; + } else { + return ADBC_STATUS_NOT_FOUND; + } + + if (output.size() + 1 <= *length) { + std::memcpy(value, output.c_str(), output.size() + 1); + } + *length = output.size() + 1; + return ADBC_STATUS_OK; +} +AdbcStatusCode PostgresConnection::GetOptionBytes(const char* option, uint8_t* value, + size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresConnection::GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresConnection::GetOptionDouble(const char* option, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresConnectionGetStatisticsImpl(PGconn* conn, const char* db_schema, + const char* table_name, + struct ArrowSchema* schema, + struct ArrowArray* array, + struct AdbcError* error) { + // Set up schema + auto uschema = nanoarrow::UniqueSchema(); + { + ArrowSchemaInit(uschema.get()); + CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), /*num_columns=*/2), error); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema->children[0], "catalog_name"), error); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema->children[1], NANOARROW_TYPE_LIST), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema->children[1], "catalog_db_schemas"), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema->children[1]->children[0], 2), + error); + uschema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* db_schema_schema = uschema->children[1]->children[0]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_statistics"), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 5), + error); + db_schema_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* statistics_schema = db_schema_schema->children[1]->children[0]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[0], "table_name"), + error); + statistics_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[1], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[1], "column_name"), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[2], NANOARROW_TYPE_INT16), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(statistics_schema->children[2], "statistic_key"), error); + statistics_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetTypeUnion(statistics_schema->children[3], + NANOARROW_TYPE_DENSE_UNION, 4), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(statistics_schema->children[3], "statistic_value"), + error); + statistics_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[4], NANOARROW_TYPE_BOOL), + error); + CHECK_NA( + INTERNAL, + ArrowSchemaSetName(statistics_schema->children[4], "statistic_is_approximate"), + error); + statistics_schema->children[4]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* value_schema = statistics_schema->children[3]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[0], NANOARROW_TYPE_INT64), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[0], "int64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[1], NANOARROW_TYPE_UINT64), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[1], "uint64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[2], NANOARROW_TYPE_DOUBLE), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[2], "float64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[3], NANOARROW_TYPE_BINARY), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[3], "binary"), error); + } + + // Set up builders + struct ArrowError na_error = {0}; + CHECK_NA_DETAIL(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), &na_error), + &na_error, error); + CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); + + struct ArrowArray* catalog_name_col = array->children[0]; + struct ArrowArray* catalog_db_schemas_col = array->children[1]; + struct ArrowArray* catalog_db_schemas_items = catalog_db_schemas_col->children[0]; + struct ArrowArray* db_schema_name_col = catalog_db_schemas_items->children[0]; + struct ArrowArray* db_schema_statistics_col = catalog_db_schemas_items->children[1]; + struct ArrowArray* db_schema_statistics_items = db_schema_statistics_col->children[0]; + struct ArrowArray* statistics_table_name_col = db_schema_statistics_items->children[0]; + struct ArrowArray* statistics_column_name_col = db_schema_statistics_items->children[1]; + struct ArrowArray* statistics_key_col = db_schema_statistics_items->children[2]; + struct ArrowArray* statistics_value_col = db_schema_statistics_items->children[3]; + struct ArrowArray* statistics_is_approximate_col = + db_schema_statistics_items->children[4]; + // struct ArrowArray* value_int64_col = statistics_value_col->children[0]; + // struct ArrowArray* value_uint64_col = statistics_value_col->children[1]; + struct ArrowArray* value_float64_col = statistics_value_col->children[2]; + // struct ArrowArray* value_binary_col = statistics_value_col->children[3]; + + // Query (could probably be massively improved) + std::string query = R"( + WITH + class AS ( + SELECT nspname, relname, reltuples + FROM pg_namespace + INNER JOIN pg_class ON pg_class.relnamespace = pg_namespace.oid + ) + SELECT tablename, attname, null_frac, avg_width, n_distinct, reltuples + FROM pg_stats + INNER JOIN class ON pg_stats.schemaname = class.nspname AND pg_stats.tablename = class.relname + WHERE pg_stats.schemaname = $1 AND tablename LIKE $2 + ORDER BY tablename +)"; + + CHECK_NA(INTERNAL, ArrowArrayAppendString(catalog_name_col, ArrowCharView(PQdb(conn))), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendString(db_schema_name_col, ArrowCharView(db_schema)), + error); + + constexpr int8_t kStatsVariantFloat64 = 2; + + std::string prev_table; + + { + PqResultHelper result_helper{ + conn, query, {db_schema, table_name ? table_name : "%"}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + + for (PqResultRow row : result_helper) { + auto reltuples = row[5].ParseDouble(); + if (!reltuples.first) { + SetError(error, "[libpq] Invalid double value in reltuples: '%s'", row[5].data); + return ADBC_STATUS_INTERNAL; + } + + if (std::strcmp(prev_table.c_str(), row[0].data) != 0) { + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendNull(statistics_column_name_col, 1), error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_ROW_COUNT_KEY), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendDouble(value_float64_col, reltuples.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); + prev_table = std::string(row[0].data, row[0].len); + } + + auto null_frac = row[2].ParseDouble(); + if (!null_frac.first) { + SetError(error, "[libpq] Invalid double value in null_frac: '%s'", row[2].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_NULL_COUNT_KEY), + error); + CHECK_NA( + INTERNAL, + ArrowArrayAppendDouble(value_float64_col, null_frac.second * reltuples.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); + + auto average_byte_width = row[3].ParseDouble(); + if (!average_byte_width.first) { + SetError(error, "[libpq] Invalid double value in avg_width: '%s'", row[3].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA( + INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendDouble(value_float64_col, average_byte_width.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); + + auto n_distinct = row[4].ParseDouble(); + if (!n_distinct.first) { + SetError(error, "[libpq] Invalid double value in avg_width: '%s'", row[4].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_DISTINCT_COUNT_KEY), + error); + // > If greater than zero, the estimated number of distinct values in + // > the column. If less than zero, the negative of the number of + // > distinct values divided by the number of rows. + // https://www.postgresql.org/docs/current/view-pg-stats.html + CHECK_NA( + INTERNAL, + ArrowArrayAppendDouble(value_float64_col, + n_distinct.second > 0 + ? n_distinct.second + : (std::fabs(n_distinct.second) * reltuples.second)), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); + } + } + + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_col), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_items), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_col), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error); + + CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array, &na_error), &na_error, + error); + uschema.move(schema); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresConnection::GetStatistics(const char* catalog, + const char* db_schema, + const char* table_name, bool approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + // Simplify our jobs here + if (!approximate) { + SetError(error, "[libpq] Exact statistics are not implemented"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } else if (!db_schema) { + SetError(error, "[libpq] Must request statistics for a single schema"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } else if (catalog && std::strcmp(catalog, PQdb(conn_)) != 0) { + SetError(error, "[libpq] Can only request statistics for current catalog"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + std::memset(&array, 0, sizeof(array)); + + AdbcStatusCode status = PostgresConnectionGetStatisticsImpl( + conn_, db_schema, table_name, &schema, &array, error); + if (status != ADBC_STATUS_OK) { + if (schema.release) schema.release(&schema); + if (array.release) array.release(&array); + return status; + } + + return BatchToArrayStream(&array, &schema, out, error); +} + +AdbcStatusCode PostgresConnectionGetStatisticNamesImpl(struct ArrowSchema* schema, + struct ArrowArray* array, + struct AdbcError* error) { + auto uschema = nanoarrow::UniqueSchema(); + ArrowSchemaInit(uschema.get()); + + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get(), NANOARROW_TYPE_STRUCT), error); + CHECK_NA(INTERNAL, ArrowSchemaAllocateChildren(uschema.get(), /*num_columns=*/2), + error); + + ArrowSchemaInit(uschema.get()->children[0]); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(uschema.get()->children[0], NANOARROW_TYPE_STRING), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[0], "statistic_name"), + error); + uschema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + ArrowSchemaInit(uschema.get()->children[1]); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get()->children[1], NANOARROW_TYPE_INT16), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[1], "statistic_key"), + error); + uschema.get()->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + CHECK_NA(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), NULL), error); + CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); + CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array, NULL), error); + + uschema.move(schema); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresConnection::GetStatisticNames(struct ArrowArrayStream* out, + struct AdbcError* error) { + // We don't support any extended statistics, just return an empty stream + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + std::memset(&array, 0, sizeof(array)); + + AdbcStatusCode status = PostgresConnectionGetStatisticNamesImpl(&schema, &array, error); + if (status != ADBC_STATUS_OK) { + if (schema.release) schema.release(&schema); + if (array.release) array.release(&array); + return status; + } + return BatchToArrayStream(&array, &schema, out, error); + + return ADBC_STATUS_OK; +} + AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, @@ -964,16 +1392,26 @@ AdbcStatusCode PostgresConnection::GetTableTypes(struct AdbcConnection* connecti AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database, struct AdbcError* error) { if (!database || !database->private_data) { - SetError(error, "%s", "[libpq] Must provide an initialized AdbcDatabase"); + SetError(error, "[libpq] Must provide an initialized AdbcDatabase"); return ADBC_STATUS_INVALID_ARGUMENT; } database_ = *reinterpret_cast*>(database->private_data); type_resolver_ = database_->type_resolver(); - return database_->Connect(&conn_, error); + RAISE_ADBC(database_->Connect(&conn_, error)); + cancel_ = PQgetCancel(conn_); + if (!cancel_) { + SetError(error, "[libpq] Could not initialize PGcancel"); + return ADBC_STATUS_UNKNOWN; + } + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Release(struct AdbcError* error) { + if (cancel_) { + PQfreeCancel(cancel_); + cancel_ = nullptr; + } if (conn_) { return database_->Disconnect(&conn_, error); } @@ -1023,8 +1461,35 @@ AdbcStatusCode PostgresConnection::SetOption(const char* key, const char* value, autocommit_ = autocommit; } return ADBC_STATUS_OK; + } else if (std::strcmp(key, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { + // PostgreSQL doesn't accept a parameter here + PqResultHelper result_helper{ + conn_, std::string("SET search_path TO ") + value, {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + return ADBC_STATUS_OK; } SetError(error, "%s%s", "[libpq] Unknown option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } + +AdbcStatusCode PostgresConnection::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresConnection::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresConnection::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + } // namespace adbcpq diff --git a/c/driver/postgresql/connection.h b/c/driver/postgresql/connection.h index 74315ee053..50d3ebe322 100644 --- a/c/driver/postgresql/connection.h +++ b/c/driver/postgresql/connection.h @@ -29,10 +29,12 @@ namespace adbcpq { class PostgresDatabase; class PostgresConnection { public: - PostgresConnection() : database_(nullptr), conn_(nullptr), autocommit_(true) {} + PostgresConnection() + : database_(nullptr), conn_(nullptr), cancel_(nullptr), autocommit_(true) {} + AdbcStatusCode Cancel(struct AdbcError* error); AdbcStatusCode Commit(struct AdbcError* error); - AdbcStatusCode GetInfo(struct AdbcConnection* connection, uint32_t* info_codes, + AdbcStatusCode GetInfo(struct AdbcConnection* connection, const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error); AdbcStatusCode GetObjects(struct AdbcConnection* connection, int depth, @@ -40,6 +42,18 @@ class PostgresConnection { const char* table_name, const char** table_types, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error); + AdbcStatusCode GetOption(const char* option, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); + AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema, + const char* table_name, bool approximate, + struct ArrowArrayStream* out, struct AdbcError* error); + AdbcStatusCode GetStatisticNames(struct ArrowArrayStream* out, struct AdbcError* error); AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error); @@ -49,6 +63,10 @@ class PostgresConnection { AdbcStatusCode Release(struct AdbcError* error); AdbcStatusCode Rollback(struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); PGconn* conn() const { return conn_; } const std::shared_ptr& type_resolver() const { @@ -60,6 +78,7 @@ class PostgresConnection { std::shared_ptr database_; std::shared_ptr type_resolver_; PGconn* conn_; + PGcancel* cancel_; bool autocommit_; }; } // namespace adbcpq diff --git a/c/driver/postgresql/database.cc b/c/driver/postgresql/database.cc index 3976c4b08d..5de8628095 100644 --- a/c/driver/postgresql/database.cc +++ b/c/driver/postgresql/database.cc @@ -36,6 +36,23 @@ PostgresDatabase::PostgresDatabase() : open_connections_(0) { } PostgresDatabase::~PostgresDatabase() = default; +AdbcStatusCode PostgresDatabase::GetOption(const char* option, char* value, + size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionBytes(const char* option, uint8_t* value, + size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionDouble(const char* option, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode PostgresDatabase::Init(struct AdbcError* error) { // Connect to validate the parameters. return RebuildTypeResolver(error); @@ -61,6 +78,24 @@ AdbcStatusCode PostgresDatabase::SetOption(const char* key, const char* value, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresDatabase::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresDatabase::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresDatabase::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode PostgresDatabase::Connect(PGconn** conn, struct AdbcError* error) { if (uri_.empty()) { SetError(error, "%s", @@ -90,10 +125,10 @@ AdbcStatusCode PostgresDatabase::Disconnect(PGconn** conn, struct AdbcError* err // Helpers for building the type resolver from queries static inline int32_t InsertPgAttributeResult( - pg_result* result, const std::shared_ptr& resolver); + PGresult* result, const std::shared_ptr& resolver); static inline int32_t InsertPgTypeResult( - pg_result* result, const std::shared_ptr& resolver); + PGresult* result, const std::shared_ptr& resolver); AdbcStatusCode PostgresDatabase::RebuildTypeResolver(struct AdbcError* error) { PGconn* conn = nullptr; @@ -142,7 +177,7 @@ ORDER BY auto resolver = std::make_shared(); // Insert record type definitions (this includes table schemas) - pg_result* result = PQexec(conn, kColumnsQuery.c_str()); + PGresult* result = PQexec(conn, kColumnsQuery.c_str()); ExecStatusType pq_status = PQresultStatus(result); if (pq_status == PGRES_TUPLES_OK) { InsertPgAttributeResult(result, resolver); @@ -187,7 +222,7 @@ ORDER BY } static inline int32_t InsertPgAttributeResult( - pg_result* result, const std::shared_ptr& resolver) { + PGresult* result, const std::shared_ptr& resolver) { int num_rows = PQntuples(result); std::vector> columns; uint32_t current_type_oid = 0; @@ -219,7 +254,7 @@ static inline int32_t InsertPgAttributeResult( } static inline int32_t InsertPgTypeResult( - pg_result* result, const std::shared_ptr& resolver) { + PGresult* result, const std::shared_ptr& resolver) { int num_rows = PQntuples(result); PostgresTypeResolver::Item item; int32_t n_added = 0; diff --git a/c/driver/postgresql/database.h b/c/driver/postgresql/database.h index f10464787a..6c3da58daa 100644 --- a/c/driver/postgresql/database.h +++ b/c/driver/postgresql/database.h @@ -36,7 +36,19 @@ class PostgresDatabase { AdbcStatusCode Init(struct AdbcError* error); AdbcStatusCode Release(struct AdbcError* error); + AdbcStatusCode GetOption(const char* option, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); // Internal implementation @@ -54,3 +66,10 @@ class PostgresDatabase { std::shared_ptr type_resolver_; }; } // namespace adbcpq + +extern "C" { +/// For applications that want to use the driver struct directly, this gives +/// them access to the Init routine. +ADBC_EXPORT +AdbcStatusCode PostgresqlDriverInit(int, void*, struct AdbcError*); +} diff --git a/c/driver/postgresql/error.cc b/c/driver/postgresql/error.cc new file mode 100644 index 0000000000..47e04496ba --- /dev/null +++ b/c/driver/postgresql/error.cc @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "error.h" + +#include +#include +#include +#include +#include + +#include + +#include "common/utils.h" + +namespace adbcpq { + +namespace { +struct DetailField { + int code; + std::string key; +}; + +static const std::vector kDetailFields = { + {PG_DIAG_COLUMN_NAME, "PG_DIAG_COLUMN_NAME"}, + {PG_DIAG_CONTEXT, "PG_DIAG_CONTEXT"}, + {PG_DIAG_CONSTRAINT_NAME, "PG_DIAG_CONSTRAINT_NAME"}, + {PG_DIAG_DATATYPE_NAME, "PG_DIAG_DATATYPE_NAME"}, + {PG_DIAG_INTERNAL_POSITION, "PG_DIAG_INTERNAL_POSITION"}, + {PG_DIAG_INTERNAL_QUERY, "PG_DIAG_INTERNAL_QUERY"}, + {PG_DIAG_MESSAGE_PRIMARY, "PG_DIAG_MESSAGE_PRIMARY"}, + {PG_DIAG_MESSAGE_DETAIL, "PG_DIAG_MESSAGE_DETAIL"}, + {PG_DIAG_MESSAGE_HINT, "PG_DIAG_MESSAGE_HINT"}, + {PG_DIAG_SEVERITY_NONLOCALIZED, "PG_DIAG_SEVERITY_NONLOCALIZED"}, + {PG_DIAG_SQLSTATE, "PG_DIAG_SQLSTATE"}, + {PG_DIAG_STATEMENT_POSITION, "PG_DIAG_STATEMENT_POSITION"}, + {PG_DIAG_SCHEMA_NAME, "PG_DIAG_SCHEMA_NAME"}, + {PG_DIAG_TABLE_NAME, "PG_DIAG_TABLE_NAME"}, +}; +} // namespace + +AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, + ...) { + va_list args; + va_start(args, format); + SetErrorVariadic(error, format, args); + va_end(args); + + AdbcStatusCode code = ADBC_STATUS_IO; + + const char* sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); + if (sqlstate) { + // https://www.postgresql.org/docs/current/errcodes-appendix.html + // This can be extended in the future + if (std::strcmp(sqlstate, "57014") == 0) { + code = ADBC_STATUS_CANCELLED; + } else if (std::strncmp(sqlstate, "42", 0) == 0) { + // Class 42 — Syntax Error or Access Rule Violation + code = ADBC_STATUS_INVALID_ARGUMENT; + } + + static_assert(sizeof(error->sqlstate) == 5, ""); + // N.B. strncpy generates warnings when used for this purpose + int i = 0; + for (; sqlstate[i] != '\0' && i < 5; i++) { + error->sqlstate[i] = sqlstate[i]; + } + for (; i < 5; i++) { + error->sqlstate[i] = '\0'; + } + } + + for (const auto& field : kDetailFields) { + const char* value = PQresultErrorField(result, field.code); + if (value) { + AppendErrorDetail(error, field.key.c_str(), reinterpret_cast(value), + std::strlen(value)); + } + } + return code; +} + +} // namespace adbcpq diff --git a/c/driver/postgresql/error.h b/c/driver/postgresql/error.h new file mode 100644 index 0000000000..75c52b46c3 --- /dev/null +++ b/c/driver/postgresql/error.h @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Error handling utilities. + +#pragma once + +#include +#include + +namespace adbcpq { + +// The printf checking attribute doesn't work properly on gcc 4.8 +// and results in spurious compiler warnings +#if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) +#define ADBC_CHECK_PRINTF_ATTRIBUTE(x, y) __attribute__((format(printf, x, y))) +#else +#define ADBC_CHECK_PRINTF_ATTRIBUTE(x, y) +#endif + +/// \brief Set an error based on a PGresult, inferring the proper ADBC status +/// code from the PGresult. +AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, + ...) ADBC_CHECK_PRINTF_ATTRIBUTE(3, 4); + +#undef ADBC_CHECK_PRINTF_ATTRIBUTE + +} // namespace adbcpq diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h index 4aa5a82e69..5c7214dc04 100644 --- a/c/driver/postgresql/postgres_copy_reader.h +++ b/c/driver/postgresql/postgres_copy_reader.h @@ -893,12 +893,13 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type, class PostgresCopyStreamReader { public: - ArrowErrorCode Init(const PostgresType& pg_type) { + ArrowErrorCode Init(PostgresType pg_type) { if (pg_type.type_id() != PostgresTypeId::kRecord) { return EINVAL; } - root_reader_.Init(pg_type); + pg_type_ = std::move(pg_type); + root_reader_.Init(pg_type_); array_size_approx_bytes_ = 0; return NANOARROW_OK; } @@ -1022,7 +1023,10 @@ class PostgresCopyStreamReader { return NANOARROW_OK; } + const PostgresType& pg_type() const { return pg_type_; } + private: + PostgresType pg_type_; PostgresCopyFieldTupleReader root_reader_; nanoarrow::UniqueSchema schema_; nanoarrow::UniqueArray array_; diff --git a/c/driver/postgresql/postgresql.cc b/c/driver/postgresql/postgresql.cc index 29fd04cddc..2e25c4bedf 100644 --- a/c/driver/postgresql/postgresql.cc +++ b/c/driver/postgresql/postgresql.cc @@ -34,7 +34,7 @@ using adbcpq::PostgresStatement; // --------------------------------------------------------------------- // ADBC interface implementation - as private functions so that these // don't get replaced by the dynamic linker. If we implemented these -// under the Adbc* names, then DriverInit, the linker may resolve +// under the Adbc* names, then in DriverInit, the linker may resolve // functions to the address of the functions provided by the driver // manager instead of our functions. // @@ -47,6 +47,30 @@ using adbcpq::PostgresStatement; // // So in the end some manual effort here was chosen. +// --------------------------------------------------------------------- +// AdbcError + +namespace { +const struct AdbcError* PostgresErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + // Currently only valid for TupleReader + return adbcpq::TupleReader::ErrorFromArrayStream(stream, status); +} +} // namespace + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return CommonErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return CommonErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return PostgresErrorFromArrayStream(stream, status); +} + // --------------------------------------------------------------------- // AdbcDatabase @@ -83,14 +107,92 @@ AdbcStatusCode PostgresDatabaseRelease(struct AdbcDatabase* database, return status; } +AdbcStatusCode PostgresDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseGetOptionBytes(struct AdbcDatabase* database, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseGetOptionDouble(struct AdbcDatabase* database, + const char* key, double* value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresDatabaseGetOptionInt(struct AdbcDatabase* database, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionInt(key, value, error); +} + AdbcStatusCode PostgresDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (!database || !database->private_data) return ADBC_STATUS_INVALID_STATE; auto ptr = reinterpret_cast*>(database->private_data); return (*ptr)->SetOption(key, value, error); } + +AdbcStatusCode PostgresDatabaseSetOptionBytes(struct AdbcDatabase* database, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseSetOptionDouble(struct AdbcDatabase* database, + const char* key, double value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresDatabaseSetOptionInt(struct AdbcDatabase* database, + const char* key, int64_t value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} } // namespace +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return PostgresDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return PostgresDatabaseGetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return PostgresDatabaseGetOptionDouble(database, key, value, error); +} + AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return PostgresDatabaseInit(database, error); } @@ -109,10 +211,34 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* return PostgresDatabaseSetOption(database, key, value, error); } +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return PostgresDatabaseSetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return PostgresDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return PostgresDatabaseSetOptionDouble(database, key, value, error); +} + // --------------------------------------------------------------------- // AdbcConnection namespace { +AdbcStatusCode PostgresConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->Cancel(error); +} + AdbcStatusCode PostgresConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; @@ -122,7 +248,8 @@ AdbcStatusCode PostgresConnectionCommit(struct AdbcConnection* connection, } AdbcStatusCode PostgresConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; @@ -142,6 +269,63 @@ AdbcStatusCode PostgresConnectionGetObjects( table_types, column_name, stream, error); } +AdbcStatusCode PostgresConnectionGetOption(struct AdbcConnection* connection, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionInt(key, value, error); +} + +AdbcStatusCode PostgresConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetStatistics(catalog, db_schema, table_name, approximate == 1, out, + error); +} + +AdbcStatusCode PostgresConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetStatisticNames(out, error); +} + AdbcStatusCode PostgresConnectionGetTableSchema( struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) { @@ -213,14 +397,47 @@ AdbcStatusCode PostgresConnectionSetOption(struct AdbcConnection* connection, return (*ptr)->SetOption(key, value, error); } +AdbcStatusCode PostgresConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} + } // namespace + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return PostgresConnectionCancel(connection, error); +} + AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { return PostgresConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { return PostgresConnectionGetInfo(connection, info_codes, info_codes_length, stream, @@ -237,6 +454,45 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d table_types, column_name, stream, error); } +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PostgresConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return PostgresConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return PostgresConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return PostgresConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return PostgresConnectionGetStatisticNames(connection, out, error); +} + AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -287,6 +543,24 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const return PostgresConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PostgresConnectionSetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return PostgresConnectionSetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return PostgresConnectionSetOptionDouble(connection, key, value, error); +} + // --------------------------------------------------------------------- // AdbcStatement @@ -310,6 +584,14 @@ AdbcStatusCode PostgresStatementBindStream(struct AdbcStatement* statement, return (*ptr)->Bind(stream, error); } +AdbcStatusCode PostgresStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->Cancel(error); +} + AdbcStatusCode PostgresStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -329,16 +611,49 @@ AdbcStatusCode PostgresStatementExecuteQuery(struct AdbcStatement* statement, return (*ptr)->ExecuteQuery(output, rows_affected, error); } -AdbcStatusCode PostgresStatementGetPartitionDesc(struct AdbcStatement* statement, - uint8_t* partition_desc, - struct AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->ExecuteSchema(schema, error); } -AdbcStatusCode PostgresStatementGetPartitionDescSize(struct AdbcStatement* statement, - size_t* length, - struct AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresStatementGetOption(struct AdbcStatement* statement, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionInt(key, value, error); } AdbcStatusCode PostgresStatementGetParameterSchema(struct AdbcStatement* statement, @@ -386,6 +701,33 @@ AdbcStatusCode PostgresStatementSetOption(struct AdbcStatement* statement, return (*ptr)->SetOption(key, value, error); } +AdbcStatusCode PostgresStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} + AdbcStatusCode PostgresStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; @@ -407,6 +749,11 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return PostgresStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return PostgresStatementCancel(statement, error); +} + AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -423,16 +770,32 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, return PostgresStatementExecuteQuery(statement, output, rows_affected, error); } -AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement, - uint8_t* partition_desc, - struct AdbcError* error) { - return PostgresStatementGetPartitionDesc(statement, partition_desc, error); +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + ArrowSchema* schema, struct AdbcError* error) { + return PostgresStatementExecuteSchema(statement, schema, error); } -AdbcStatusCode AdbcStatementGetPartitionDescSize(struct AdbcStatement* statement, - size_t* length, - struct AdbcError* error) { - return PostgresStatementGetPartitionDescSize(statement, length, error); +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PostgresStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return PostgresStatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return PostgresStatementGetOptionDouble(statement, key, value, error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -462,6 +825,23 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha return PostgresStatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PostgresStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return PostgresStatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return PostgresStatementSetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { return PostgresStatementSetSqlQuery(statement, query, error); @@ -469,11 +849,53 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, extern "C" { ADBC_EXPORT -AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) { - if (version != ADBC_VERSION_1_0_0) return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresqlDriverInit(int version, void* raw_driver, + struct AdbcError* error) { + if (version != ADBC_VERSION_1_0_0 && version != ADBC_VERSION_1_1_0) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + if (!raw_driver) return ADBC_STATUS_INVALID_ARGUMENT; auto* driver = reinterpret_cast(raw_driver); - std::memset(driver, 0, sizeof(*driver)); + if (version >= ADBC_VERSION_1_1_0) { + std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE); + + driver->ErrorGetDetailCount = CommonErrorGetDetailCount; + driver->ErrorGetDetail = CommonErrorGetDetail; + driver->ErrorFromArrayStream = PostgresErrorFromArrayStream; + + driver->DatabaseGetOption = PostgresDatabaseGetOption; + driver->DatabaseGetOptionBytes = PostgresDatabaseGetOptionBytes; + driver->DatabaseGetOptionDouble = PostgresDatabaseGetOptionDouble; + driver->DatabaseGetOptionInt = PostgresDatabaseGetOptionInt; + driver->DatabaseSetOptionBytes = PostgresDatabaseSetOptionBytes; + driver->DatabaseSetOptionDouble = PostgresDatabaseSetOptionDouble; + driver->DatabaseSetOptionInt = PostgresDatabaseSetOptionInt; + + driver->ConnectionCancel = PostgresConnectionCancel; + driver->ConnectionGetOption = PostgresConnectionGetOption; + driver->ConnectionGetOptionBytes = PostgresConnectionGetOptionBytes; + driver->ConnectionGetOptionDouble = PostgresConnectionGetOptionDouble; + driver->ConnectionGetOptionInt = PostgresConnectionGetOptionInt; + driver->ConnectionGetStatistics = PostgresConnectionGetStatistics; + driver->ConnectionGetStatisticNames = PostgresConnectionGetStatisticNames; + driver->ConnectionSetOptionBytes = PostgresConnectionSetOptionBytes; + driver->ConnectionSetOptionDouble = PostgresConnectionSetOptionDouble; + driver->ConnectionSetOptionInt = PostgresConnectionSetOptionInt; + + driver->StatementCancel = PostgresStatementCancel; + driver->StatementExecuteSchema = PostgresStatementExecuteSchema; + driver->StatementGetOption = PostgresStatementGetOption; + driver->StatementGetOptionBytes = PostgresStatementGetOptionBytes; + driver->StatementGetOptionDouble = PostgresStatementGetOptionDouble; + driver->StatementGetOptionInt = PostgresStatementGetOptionInt; + driver->StatementSetOptionBytes = PostgresStatementSetOptionBytes; + driver->StatementSetOptionDouble = PostgresStatementSetOptionDouble; + driver->StatementSetOptionInt = PostgresStatementSetOptionInt; + } else { + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + } + driver->DatabaseInit = PostgresDatabaseInit; driver->DatabaseNew = PostgresDatabaseNew; driver->DatabaseRelease = PostgresDatabaseRelease; @@ -501,6 +923,12 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e driver->StatementRelease = PostgresStatementRelease; driver->StatementSetOption = PostgresStatementSetOption; driver->StatementSetSqlQuery = PostgresStatementSetSqlQuery; + return ADBC_STATUS_OK; } + +ADBC_EXPORT +AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) { + return PostgresqlDriverInit(version, raw_driver, error); +} } diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index a826e17267..84a264f3c8 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include +#include #include #include #include @@ -25,8 +27,9 @@ #include #include #include -#include "common/utils.h" +#include "common/utils.h" +#include "database.h" #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" @@ -103,6 +106,30 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { std::string catalog() const override { return "postgres"; } std::string db_schema() const override { return "public"; } + + bool supports_cancel() const override { return true; } + bool supports_execute_schema() const override { return true; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_ADBC_VERSION: + return ADBC_VERSION_1_1_0; + case ADBC_INFO_DRIVER_NAME: + return "ADBC PostgreSQL Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "PostgreSQL"; + case ADBC_INFO_VENDOR_VERSION: + // Strings are checked via substring match + return "15"; + default: + return std::nullopt; + } + } + bool supports_metadata_current_catalog() const override { return true; } + bool supports_metadata_current_db_schema() const override { return true; } + bool supports_statistics() const override { return true; } }; class PostgresDatabaseTest : public ::testing::Test, @@ -117,6 +144,20 @@ class PostgresDatabaseTest : public ::testing::Test, }; ADBCV_TEST_DATABASE(PostgresDatabaseTest) +TEST_F(PostgresDatabaseTest, AdbcDriverBackwardsCompatibility) { + // XXX: sketchy cast + auto* driver = static_cast(malloc(ADBC_DRIVER_1_0_0_SIZE)); + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + + ASSERT_THAT(::PostgresqlDriverInit(ADBC_VERSION_1_0_0, driver, &error), + IsOkStatus(&error)); + + ASSERT_THAT(::PostgresqlDriverInit(424242, driver, &error), + IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + free(driver); +} + class PostgresConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: @@ -134,10 +175,8 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { adbc_validation::StreamReader reader; std::vector info = { - ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, - ADBC_INFO_VENDOR_NAME, - ADBC_INFO_VENDOR_VERSION, + ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, }; ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(), &reader.stream.value, &error), @@ -153,29 +192,30 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); const uint32_t code = reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; + const uint32_t offset = + reader.array_view->children[1]->buffer_views[1].data.as_int32[row]; seen.push_back(code); - int str_child_index = 0; - struct ArrowArrayView* str_child = - reader.array_view->children[1]->children[str_child_index]; + struct ArrowArrayView* str_child = reader.array_view->children[1]->children[0]; + struct ArrowArrayView* int_child = reader.array_view->children[1]->children[2]; switch (code) { case ADBC_INFO_DRIVER_NAME: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 0); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("ADBC PostgreSQL Driver", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_DRIVER_VERSION: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 1); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("(unknown)", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_VENDOR_NAME: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 2); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("PostgreSQL", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_VENDOR_VERSION: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 3); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); #ifdef __WIN32 const char* pater = "\\d\\d\\d\\d\\d\\d"; #else @@ -185,6 +225,10 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { ::testing::MatchesRegex(pater)); break; } + case ADBC_INFO_DRIVER_ADBC_VERSION: { + EXPECT_EQ(ADBC_VERSION_1_1_0, ArrowArrayViewGetIntUnsafe(int_child, offset)); + break; + } default: // Ignored break; @@ -198,10 +242,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetCatalogs) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - adbc_validation::StreamReader reader; ASSERT_THAT( AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_CATALOGS, nullptr, nullptr, @@ -228,10 +268,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetDbSchemas) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - adbc_validation::StreamReader reader; ASSERT_THAT(AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_DB_SCHEMAS, nullptr, nullptr, nullptr, nullptr, nullptr, @@ -255,10 +291,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsPrimaryKey) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropTable(&connection, "adbc_pkey_test", &error), IsOkStatus(&error)); @@ -329,10 +361,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsForeignKey) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropTable(&connection, "adbc_fkey_test", &error), IsOkStatus(&error)); ASSERT_THAT(quirks()->DropTable(&connection, "adbc_fkey_test_base", &error), @@ -450,10 +478,6 @@ TEST_F(PostgresConnectionTest, GetObjectsTableTypesFilter) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropView(&connection, "adbc_table_types_view_test", &error), IsOkStatus(&error)); ASSERT_THAT(quirks()->DropTable(&connection, "adbc_table_types_table_test", &error), @@ -516,7 +540,7 @@ TEST_F(PostgresConnectionTest, GetObjectsTableTypesFilter) { } TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); @@ -531,13 +555,13 @@ TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) { /*db_schema=*/nullptr, "0'::int; DROP TABLE bulk_ingest;--", &schema.value, &error), - IsStatus(ADBC_STATUS_IO, &error)); + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); ASSERT_THAT( AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, /*db_schema=*/"0'::int; DROP TABLE bulk_ingest;--", "DROP TABLE bulk_ingest;", &schema.value, &error), - IsStatus(ADBC_STATUS_IO, &error)); + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, /*db_schema=*/nullptr, "bulk_ingest", @@ -549,6 +573,236 @@ TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) { {"strings", NANOARROW_TYPE_STRING, true}})); } +TEST_F(PostgresConnectionTest, MetadataSetCurrentDbSchema) { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement.value, "CREATE SCHEMA IF NOT EXISTS testschema", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement.value, + "CREATE TABLE IF NOT EXISTS testschema.schematable (ints INT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + // Table does not exist in this schema + error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM schematable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); + // 42P01 = table not found + ASSERT_EQ("42P01", std::string_view(error.sqlstate, 5)); + ASSERT_NE(0, AdbcErrorGetDetailCount(&error)); + bool found = false; + for (int i = 0; i < AdbcErrorGetDetailCount(&error); i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(&error, i); + if (std::strcmp(detail.key, "PG_DIAG_MESSAGE_PRIMARY") == 0) { + found = true; + std::string_view message(reinterpret_cast(detail.value), + detail.value_length); + ASSERT_THAT(message, ::testing::HasSubstr("schematable")); + } + } + error.release(&error); + ASSERT_TRUE(found) << "Did not find expected error detail"; + + ASSERT_THAT( + AdbcConnectionSetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + "testschema", &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM schematable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +TEST_F(PostgresConnectionTest, MetadataGetStatistics) { + if (!quirks()->supports_statistics()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + // Create sample table + { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + "DROP TABLE IF EXISTS statstable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement.value, + "CREATE TABLE statstable (ints INT, strs TEXT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement.value, + "INSERT INTO statstable VALUES (1, 'a'), (NULL, 'bcd'), (-5, NULL)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "ANALYZE statstable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + adbc_validation::StreamReader reader; + ASSERT_THAT( + AdbcConnectionGetStatistics(&connection, nullptr, quirks()->db_schema().c_str(), + "statstable", 1, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + &reader.schema.value, { + {"catalog_name", NANOARROW_TYPE_STRING, true}, + {"catalog_db_schemas", NANOARROW_TYPE_LIST, false}, + })); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + reader.schema->children[1]->children[0], + { + {"db_schema_name", NANOARROW_TYPE_STRING, true}, + {"db_schema_statistics", NANOARROW_TYPE_LIST, false}, + })); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + reader.schema->children[1]->children[0]->children[1]->children[0], + { + {"table_name", NANOARROW_TYPE_STRING, false}, + {"column_name", NANOARROW_TYPE_STRING, true}, + {"statistic_key", NANOARROW_TYPE_INT16, false}, + {"statistic_value", NANOARROW_TYPE_DENSE_UNION, false}, + {"statistic_is_approximate", NANOARROW_TYPE_BOOL, false}, + })); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + reader.schema->children[1]->children[0]->children[1]->children[0]->children[3], + { + {"int64", NANOARROW_TYPE_INT64, true}, + {"uint64", NANOARROW_TYPE_UINT64, true}, + {"float64", NANOARROW_TYPE_DOUBLE, true}, + {"binary", NANOARROW_TYPE_BINARY, true}, + })); + + std::vector, int16_t, int64_t>> seen; + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; + + for (int64_t catalog_index = 0; catalog_index < reader.array->length; + catalog_index++) { + struct ArrowStringView catalog_name = + ArrowArrayViewGetStringUnsafe(reader.array_view->children[0], catalog_index); + ASSERT_EQ(quirks()->catalog(), + std::string_view(catalog_name.data, + static_cast(catalog_name.size_bytes))); + + struct ArrowArrayView* catalog_db_schemas = reader.array_view->children[1]; + struct ArrowArrayView* schema_stats = catalog_db_schemas->children[0]->children[1]; + struct ArrowArrayView* stats = + catalog_db_schemas->children[0]->children[1]->children[0]; + for (int64_t schema_index = + ArrowArrayViewListChildOffset(catalog_db_schemas, catalog_index); + schema_index < + ArrowArrayViewListChildOffset(catalog_db_schemas, catalog_index + 1); + schema_index++) { + struct ArrowStringView schema_name = ArrowArrayViewGetStringUnsafe( + catalog_db_schemas->children[0]->children[0], schema_index); + ASSERT_EQ(quirks()->db_schema(), + std::string_view(schema_name.data, + static_cast(schema_name.size_bytes))); + + for (int64_t stat_index = + ArrowArrayViewListChildOffset(schema_stats, schema_index); + stat_index < ArrowArrayViewListChildOffset(schema_stats, schema_index + 1); + stat_index++) { + struct ArrowStringView table_name = + ArrowArrayViewGetStringUnsafe(stats->children[0], stat_index); + ASSERT_EQ("statstable", + std::string_view(table_name.data, + static_cast(table_name.size_bytes))); + std::optional column_name; + if (!ArrowArrayViewIsNull(stats->children[1], stat_index)) { + struct ArrowStringView value = + ArrowArrayViewGetStringUnsafe(stats->children[1], stat_index); + column_name = std::string(value.data, value.size_bytes); + } + ASSERT_TRUE(ArrowArrayViewGetIntUnsafe(stats->children[4], stat_index)); + + const int16_t stat_key = static_cast( + ArrowArrayViewGetIntUnsafe(stats->children[2], stat_index)); + const int32_t offset = + stats->children[3]->buffer_views[1].data.as_int32[stat_index]; + int64_t stat_value; + switch (stat_key) { + case ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY: + case ADBC_STATISTIC_DISTINCT_COUNT_KEY: + case ADBC_STATISTIC_NULL_COUNT_KEY: + case ADBC_STATISTIC_ROW_COUNT_KEY: + stat_value = static_cast( + std::round(100 * ArrowArrayViewGetDoubleUnsafe( + stats->children[3]->children[2], offset))); + break; + default: + continue; + } + seen.emplace_back(std::move(column_name), stat_key, stat_value); + } + } + } + } + + ASSERT_THAT(seen, + ::testing::UnorderedElementsAreArray( + std::vector, int16_t, int64_t>>{ + {"ints", ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY, 400}, + {"strs", ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY, 300}, + {"ints", ADBC_STATISTIC_NULL_COUNT_KEY, 100}, + {"strs", ADBC_STATISTIC_NULL_COUNT_KEY, 100}, + {"ints", ADBC_STATISTIC_DISTINCT_COUNT_KEY, 200}, + {"strs", ADBC_STATISTIC_DISTINCT_COUNT_KEY, 200}, + {std::nullopt, ADBC_STATISTIC_ROW_COUNT_KEY, 300}, + })); +} + ADBCV_TEST_CONNECTION(PostgresConnectionTest) class PostgresStatementTest : public ::testing::Test, @@ -704,6 +958,67 @@ TEST_F(PostgresStatementTest, BatchSizeHint) { } } +// Test that an ADBC 1.0.0-sized error still works +TEST_F(PostgresStatementTest, AdbcErrorBackwardsCompatibility) { + // XXX: sketchy cast + auto* error = static_cast(malloc(ADBC_ERROR_1_0_0_SIZE)); + std::memset(error, 0, ADBC_ERROR_1_0_0_SIZE); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, error), IsOkStatus(error)); + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT * FROM thistabledoesnotexist", error), + IsOkStatus(error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, error), + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, error)); + + ASSERT_EQ("42P01", std::string_view(error->sqlstate, 5)); + ASSERT_EQ(0, AdbcErrorGetDetailCount(error)); + + error->release(error); + free(error); +} + +TEST_F(PostgresStatementTest, Cancel) { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + for (const char* query : { + "DROP TABLE IF EXISTS test_cancel", + "CREATE TABLE test_cancel (ints INT)", + R"(INSERT INTO test_cancel (ints) + SELECT g :: INT FROM GENERATE_SERIES(1, 65536) temp(g))", + }) { + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM test_cancel", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementCancel(&statement, &error), IsOkStatus(&error)); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_EQ(ECANCELED, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* detail = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(ADBC_STATUS_CANCELLED, status); + ASSERT_EQ("57014", std::string_view(detail->sqlstate, 5)); + ASSERT_NE(0, AdbcErrorGetDetailCount(detail)); +} + struct TypeTestCase { std::string name; std::string sql_type; diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 0b8f1fc080..6ea52b53ea 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -33,6 +33,7 @@ #include "common/utils.h" #include "connection.h" +#include "error.h" #include "postgres_copy_reader.h" #include "postgres_type.h" #include "postgres_util.h" @@ -272,20 +273,23 @@ struct BindStream { if (autocommit) { PGresult* begin_result = PQexec(conn, "BEGIN"); if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to begin transaction for timezone data: %s", - PQerrorMessage(conn)); + AdbcStatusCode code = + SetError(error, begin_result, + "[libpq] Failed to begin transaction for timezone data: %s", + PQerrorMessage(conn)); PQclear(begin_result); - return ADBC_STATUS_IO; + return code; } PQclear(begin_result); } PGresult* get_tz_result = PQexec(conn, "SELECT current_setting('TIMEZONE')"); if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { - SetError(error, "[libpq] Could not query current timezone: %s", - PQerrorMessage(conn)); + AdbcStatusCode code = SetError(error, get_tz_result, + "[libpq] Could not query current timezone: %s", + PQerrorMessage(conn)); PQclear(get_tz_result); - return ADBC_STATUS_IO; + return code; } tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); @@ -293,10 +297,11 @@ struct BindStream { PGresult* set_utc_result = PQexec(conn, "SET TIME ZONE 'UTC'"); if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to set time zone to UTC: %s", - PQerrorMessage(conn)); + AdbcStatusCode code = SetError(error, set_utc_result, + "[libpq] Failed to set time zone to UTC: %s", + PQerrorMessage(conn)); PQclear(set_utc_result); - return ADBC_STATUS_IO; + return code; } PQclear(set_utc_result); break; @@ -306,10 +311,11 @@ struct BindStream { PGresult* result = PQprepare(conn, /*stmtName=*/"", query.c_str(), /*nParams=*/bind_schema->n_children, param_types.data()); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(conn), query.c_str()); + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", + PQerrorMessage(conn), query.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; @@ -476,10 +482,11 @@ struct BindStream { /*resultFormat=*/0 /*text*/); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "%s%s", "[libpq] Failed to execute prepared statement: ", - PQerrorMessage(conn)); + AdbcStatusCode code = SetError( + error, result, "%s%s", + "[libpq] Failed to execute prepared statement: ", PQerrorMessage(conn)); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); @@ -490,18 +497,21 @@ struct BindStream { std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; PGresult* reset_tz_result = PQexec(conn, reset_query.c_str()); if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to reset time zone: %s", PQerrorMessage(conn)); + AdbcStatusCode code = + SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", + PQerrorMessage(conn)); PQclear(reset_tz_result); - return ADBC_STATUS_IO; + return code; } PQclear(reset_tz_result); PGresult* commit_result = PQexec(conn, "COMMIT"); if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to commit transaction: %s", - PQerrorMessage(conn)); + AdbcStatusCode code = + SetError(error, commit_result, "[libpq] Failed to commit transaction: %s", + PQerrorMessage(conn)); PQclear(commit_result); - return ADBC_STATUS_IO; + return code; } PQclear(commit_result); } @@ -516,12 +526,13 @@ int TupleReader::GetSchema(struct ArrowSchema* out) { int na_res = copy_reader_->GetSchema(out); if (out->release == nullptr) { - StringBuilderAppend(&error_builder_, - "[libpq] Result set was already consumed or freed"); - return EINVAL; + SetError(&error_, "[libpq] Result set was already consumed or freed"); + status_ = ADBC_STATUS_INVALID_STATE; + return AdbcStatusCodeToErrno(status_); } else if (na_res != NANOARROW_OK) { // e.g., Can't allocate memory - StringBuilderAppend(&error_builder_, "[libpq] Error copying schema"); + SetError(&error_, "[libpq] Error copying schema"); + status_ = ADBC_STATUS_INTERNAL; } return na_res; @@ -534,15 +545,16 @@ int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) { data_.data.as_char = pgbuf_; if (get_copy_res == -2) { - StringBuilderAppend(&error_builder_, "[libpq] Fetch header failed: %s", - PQerrorMessage(conn_)); - return EIO; + SetError(&error_, "[libpq] Fetch header failed: %s", PQerrorMessage(conn_)); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); } int na_res = copy_reader_->ReadHeader(&data_, error); if (na_res != NANOARROW_OK) { - StringBuilderAppend(&error_builder_, "[libpq] ReadHeader failed: %s", error->message); - return EIO; + SetError(&error_, "[libpq] ReadHeader failed: %s", error->message); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); } return NANOARROW_OK; @@ -553,9 +565,9 @@ int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) { // call to PQgetCopyData()) int na_res = copy_reader_->ReadRecord(&data_, error); if (na_res != NANOARROW_OK && na_res != ENODATA) { - StringBuilderAppend(&error_builder_, - "[libpq] ReadRecord failed at row %" PRId64 ": %s", row_id_, - error->message); + SetError(&error_, "[libpq] ReadRecord failed at row %" PRId64 ": %s", row_id_, + error->message); + status_ = ADBC_STATUS_IO; return na_res; } @@ -569,10 +581,10 @@ int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) { data_.data.as_char = pgbuf_; if (get_copy_res == -2) { - StringBuilderAppend(&error_builder_, - "[libpq] PQgetCopyData failed at row %" PRId64 ": %s", row_id_, - PQerrorMessage(conn_)); - return EIO; + SetError(&error_, "[libpq] PQgetCopyData failed at row %" PRId64 ": %s", row_id_, + PQerrorMessage(conn_)); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); } else if (get_copy_res == -1) { // Returned when COPY has finished successfully return ENODATA; @@ -594,8 +606,8 @@ int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) { int na_res = copy_reader_->GetArray(out, error); if (na_res != NANOARROW_OK) { - StringBuilderAppend(&error_builder_, "[libpq] Failed to build result array: %s", - error->message); + SetError(&error_, "[libpq] Failed to build result array: %s", error->message); + status_ = ADBC_STATUS_INTERNAL; return na_res; } @@ -641,16 +653,22 @@ int TupleReader::GetNext(struct ArrowArray* out) { // Check the server-side response result_ = PQgetResult(conn_); - const int pq_status = PQresultStatus(result_); + const ExecStatusType pq_status = PQresultStatus(result_); if (pq_status != PGRES_COMMAND_OK) { - StringBuilderAppend(&error_builder_, "[libpq] Query failed [%d]: %s", pq_status, - PQresultErrorMessage(result_)); + const char* sqlstate = PQresultErrorField(result_, PG_DIAG_SQLSTATE); + SetError(&error_, result_, "[libpq] Query failed [%s]: %s", PQresStatus(pq_status), + PQresultErrorMessage(result_)); if (tmp.release != nullptr) { tmp.release(&tmp); } - return EIO; + if (sqlstate != nullptr && std::strcmp(sqlstate, "57014") == 0) { + status_ = ADBC_STATUS_CANCELLED; + } else { + status_ = ADBC_STATUS_IO; + } + return AdbcStatusCodeToErrno(status_); } ArrowArrayMove(&tmp, out); @@ -658,7 +676,11 @@ int TupleReader::GetNext(struct ArrowArray* out) { } void TupleReader::Release() { - StringBuilderReset(&error_builder_); + if (error_.release) { + error_.release(&error_); + } + error_ = ADBC_ERROR_INIT; + status_ = ADBC_STATUS_OK; if (result_) { PQclear(result_); @@ -686,6 +708,19 @@ void TupleReader::ExportTo(struct ArrowArrayStream* stream) { stream->private_data = this; } +const struct AdbcError* TupleReader::ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + if (!stream->private_data || stream->release != &ReleaseTrampoline) { + return nullptr; + } + + TupleReader* reader = static_cast(stream->private_data); + if (status) { + *status = reader->status_; + } + return &reader->error_; +} + int TupleReader::GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out) { if (!self || !self->private_data) return EINVAL; @@ -767,11 +802,42 @@ AdbcStatusCode PostgresStatement::Bind(struct ArrowArrayStream* stream, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) { + // Ultimately the same underlying PGconn + return connection_->Cancel(error); +} + AdbcStatusCode PostgresStatement::CreateBulkTable( const struct ArrowSchema& source_schema, const std::vector& source_schema_fields, struct AdbcError* error) { std::string create = "CREATE TABLE "; + switch (ingest_.mode) { + case IngestMode::kCreate: + // Nothing to do + break; + case IngestMode::kAppend: + return ADBC_STATUS_OK; + case IngestMode::kReplace: { + std::string drop = "DROP TABLE IF EXISTS " + ingest_.target; + PGresult* result = PQexecParams(connection_->conn(), drop.c_str(), /*nParams=*/0, + /*paramTypes=*/nullptr, /*paramValues=*/nullptr, + /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, + /*resultFormat=*/1 /*(binary)*/); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to drop table: %s\nQuery was: %s", + PQerrorMessage(connection_->conn()), drop.c_str()); + PQclear(result); + return code; + } + PQclear(result); + break; + } + case IngestMode::kCreateAppend: + create += "IF NOT EXISTS "; + break; + } create += ingest_.target; create += " ("; @@ -830,10 +896,11 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, /*resultFormat=*/1 /*(binary)*/); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to create table: %s\nQuery was: %s", - PQerrorMessage(connection_->conn()), create.c_str()); + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to create table: %s\nQuery was: %s", + PQerrorMessage(connection_->conn()), create.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; @@ -894,50 +961,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // 1. Prepare the query to get the schema { - // TODO: we should pipeline here and assume this will succeed - PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), - /*nParams=*/0, nullptr); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, - "[libpq] Failed to execute query: could not infer schema: failed to " - "prepare query: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - PQclear(result); - result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, - "[libpq] Failed to execute query: could not infer schema: failed to " - "describe prepared statement: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - - // Resolve the information from the PGresult into a PostgresType - PostgresType root_type; - AdbcStatusCode status = - ResolvePostgresType(*type_resolver_, result, &root_type, error); - PQclear(result); - if (status != ADBC_STATUS_OK) return status; - - // Initialize the copy reader and infer the output schema (i.e., error for - // unsupported types before issuing the COPY query) - reader_.copy_reader_.reset(new PostgresCopyStreamReader()); - reader_.copy_reader_->Init(root_type); - struct ArrowError na_error; - int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); - if (na_res != NANOARROW_OK) { - SetError(error, "[libpq] Failed to infer output schema: %s", na_error.message); - return na_res; - } + RAISE_ADBC(SetupReader(error)); // If the caller did not request a result set or if there are no // inferred output columns (e.g. a CREATE or UPDATE), then don't // use COPY (which would fail anyways) - if (!stream || root_type.n_children() == 0) { + if (!stream || reader_.copy_reader_->pg_type().n_children() == 0) { RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error)); if (stream) { struct ArrowSchema schema; @@ -951,7 +980,8 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // This resolves the reader specific to each PostgresType -> ArrowSchema // conversion. It is unlikely that this will fail given that we have just // inferred these conversions ourselves. - na_res = reader_.copy_reader_->InitFieldReaders(&na_error); + struct ArrowError na_error; + int na_res = reader_.copy_reader_->InitFieldReaders(&na_error); if (na_res != NANOARROW_OK) { SetError(error, "[libpq] Failed to initialize field readers: %s", na_error.message); return na_res; @@ -966,11 +996,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, /*paramTypes=*/nullptr, /*paramValues=*/nullptr, /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat); if (PQresultStatus(reader_.result_) != PGRES_COPY_OUT) { - SetError(error, - "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s", - PQerrorMessage(connection_->conn()), copy_query.c_str()); + AdbcStatusCode code = SetError( + error, reader_.result_, + "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s", + PQerrorMessage(connection_->conn()), copy_query.c_str()); ClearResult(); - return ADBC_STATUS_IO; + return code; } // Result is read from the connection, not the result, but we won't clear it here } @@ -980,6 +1011,23 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema, + struct AdbcError* error) { + ClearResult(); + if (query_.empty()) { + SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery"); + return ADBC_STATUS_INVALID_STATE; + } else if (bind_.release) { + // TODO: if we have parameters, bind them (since they can affect the output schema) + SetError(error, "[libpq] ExecuteSchema with parameters is not implemented"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + RAISE_ADBC(SetupReader(error)); + CHECK_NA(INTERNAL, reader_.copy_reader_->GetSchema(schema), error); + return ADBC_STATUS_OK; +} + AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error) { if (!bind_.release) { @@ -991,12 +1039,8 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, std::memset(&bind_, 0, sizeof(bind_)); RAISE_ADBC(bind_stream.Begin( [&]() -> AdbcStatusCode { - if (!ingest_.append) { - // CREATE TABLE - return CreateBulkTable(bind_stream.bind_schema.value, - bind_stream.bind_schema_fields, error); - } - return ADBC_STATUS_OK; + return CreateBulkTable(bind_stream.bind_schema.value, + bind_stream.bind_schema_fields, error); }, error)); RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); @@ -1024,17 +1068,77 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateQuery(int64_t* rows_affected, PQexecPrepared(connection_->conn(), /*stmtName=*/"", /*nParams=*/0, /*paramValues=*/nullptr, /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, /*resultFormat=*/kPgBinaryFormat); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to execute query: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); + ExecStatusType status = PQresultStatus(result); + if (status != PGRES_COMMAND_OK && status != PGRES_TUPLES_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to execute query: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } if (rows_affected) *rows_affected = PQntuples(reader_.result_); PQclear(result); return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t* length, + struct AdbcError* error) { + std::string result; + if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { + result = ingest_.target; + } else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { + switch (ingest_.mode) { + case IngestMode::kCreate: + result = ADBC_INGEST_OPTION_MODE_CREATE; + break; + case IngestMode::kAppend: + result = ADBC_INGEST_OPTION_MODE_APPEND; + break; + case IngestMode::kReplace: + result = ADBC_INGEST_OPTION_MODE_REPLACE; + break; + case IngestMode::kCreateAppend: + result = ADBC_INGEST_OPTION_MODE_CREATE_APPEND; + break; + } + } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + result = std::to_string(reader_.batch_size_hint_bytes_); + } else { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; + } + + if (result.size() + 1 <= *length) { + std::memcpy(value, result.data(), result.size() + 1); + } + *length = static_cast(result.size() + 1); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresStatement::GetOptionBytes(const char* key, uint8_t* value, + size_t* length, + struct AdbcError* error) { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresStatement::GetOptionDouble(const char* key, double* value, + struct AdbcError* error) { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresStatement::GetOptionInt(const char* key, int64_t* value, + struct AdbcError* error) { + std::string result; + if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + *value = reader_.batch_size_hint_bytes_; + return ADBC_STATUS_OK; + } + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode PostgresStatement::GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -1073,16 +1177,22 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { query_.clear(); ingest_.target = value; + prepared_ = false; } else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) { - ingest_.append = false; + ingest_.mode = IngestMode::kCreate; } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { - ingest_.append = true; + ingest_.mode = IngestMode::kAppend; + } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_REPLACE) == 0) { + ingest_.mode = IngestMode::kReplace; + } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE_APPEND) == 0) { + ingest_.mode = IngestMode::kCreateAppend; } else { SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); return ADBC_STATUS_INVALID_ARGUMENT; } - } else if (std::strcmp(value, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES)) { + prepared_ = false; + } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { int64_t int_value = std::atol(value); if (int_value <= 0) { SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); @@ -1091,12 +1201,84 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, this->reader_.batch_size_hint_bytes_ = int_value; } else { - SetError(error, "[libq] Unknown statement option '%s'", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_IMPLEMENTED; } return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown statement option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown statement option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + if (value <= 0) { + SetError(error, "[libpq] Invalid value '%" PRIi64 "' for option '%s'", value, key); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + this->reader_.batch_size_hint_bytes_ = value; + return ADBC_STATUS_OK; + } + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetupReader(struct AdbcError* error) { + // TODO: we should pipeline here and assume this will succeed + PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), + /*nParams=*/0, nullptr); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, + "[libpq] Failed to execute query: could not infer schema: failed to " + "prepare query: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); + PQclear(result); + return code; + } + PQclear(result); + result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, + "[libpq] Failed to execute query: could not infer schema: failed to " + "describe prepared statement: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); + PQclear(result); + return code; + } + + // Resolve the information from the PGresult into a PostgresType + PostgresType root_type; + AdbcStatusCode status = ResolvePostgresType(*type_resolver_, result, &root_type, error); + PQclear(result); + if (status != ADBC_STATUS_OK) return status; + + // Initialize the copy reader and infer the output schema (i.e., error for + // unsupported types before issuing the COPY query) + reader_.copy_reader_.reset(new PostgresCopyStreamReader()); + reader_.copy_reader_->Init(root_type); + struct ArrowError na_error; + int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); + if (na_res != NANOARROW_OK) { + SetError(error, "[libpq] Failed to infer output schema: (%d) %s: %s", na_res, + std::strerror(na_res), na_error.message); + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; +} + void PostgresStatement::ClearResult() { // TODO: we may want to synchronize here for safety reader_.Release(); diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h index 0326e80e08..59d3032cf4 100644 --- a/c/driver/postgresql/statement.h +++ b/c/driver/postgresql/statement.h @@ -41,30 +41,28 @@ class PostgresStatement; class TupleReader final { public: TupleReader(PGconn* conn) - : conn_(conn), + : status_(ADBC_STATUS_OK), + error_(ADBC_ERROR_INIT), + conn_(conn), result_(nullptr), pgbuf_(nullptr), copy_reader_(nullptr), row_id_(-1), batch_size_hint_bytes_(16777216), is_finished_(false) { - StringBuilderInit(&error_builder_, 0); data_.data.as_char = nullptr; data_.size_bytes = 0; } int GetSchema(struct ArrowSchema* out); int GetNext(struct ArrowArray* out); - const char* last_error() const { - if (error_builder_.size > 0) { - return error_builder_.buffer; - } else { - return nullptr; - } - } + const char* last_error() const { return error_.message; } void Release(); void ExportTo(struct ArrowArrayStream* stream); + static const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + private: friend class PostgresStatement; @@ -77,11 +75,12 @@ class TupleReader final { static const char* GetLastErrorTrampoline(struct ArrowArrayStream* self); static void ReleaseTrampoline(struct ArrowArrayStream* self); + AdbcStatusCode status_; + struct AdbcError error_; PGconn* conn_; PGresult* result_; char* pgbuf_; struct ArrowBufferView data_; - struct StringBuilder error_builder_; std::unique_ptr copy_reader_; int64_t row_id_; int64_t batch_size_hint_bytes_; @@ -101,13 +100,25 @@ class PostgresStatement { AdbcStatusCode Bind(struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error); AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error); + AdbcStatusCode Cancel(struct AdbcError* error); AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); + AdbcStatusCode ExecuteSchema(struct ArrowSchema* schema, struct AdbcError* error); + AdbcStatusCode GetOption(const char* key, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* key, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* key, double* value, struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* key, int64_t* value, struct AdbcError* error); AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error); AdbcStatusCode New(struct AdbcConnection* connection, struct AdbcError* error); AdbcStatusCode Prepare(struct AdbcError* error); AdbcStatusCode Release(struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); AdbcStatusCode SetSqlQuery(const char* query, struct AdbcError* error); // --------------------------------------------------------------------- @@ -123,6 +134,7 @@ class PostgresStatement { AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); + AdbcStatusCode SetupReader(struct AdbcError* error); private: std::shared_ptr type_resolver_; @@ -134,9 +146,16 @@ class PostgresStatement { struct ArrowArrayStream bind_; // Bulk ingest state + enum class IngestMode { + kCreate, + kAppend, + kReplace, + kCreateAppend, + }; + struct { std::string target; - bool append = false; + IngestMode mode = IngestMode::kCreate; } ingest_; TupleReader reader_; diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index ed2f5de07c..8c3cd72c8b 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -106,11 +106,13 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { } std::string BindParameter(int index) const override { return "?"; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return true; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return true; } bool supports_get_sql_info() const override { return false; } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return true; } + bool supports_metadata_current_catalog() const override { return false; } + bool supports_metadata_current_db_schema() const override { return false; } bool supports_partitioned_data() const override { return false; } bool supports_dynamic_parameter_binding() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } @@ -156,6 +158,10 @@ class SnowflakeConnectionTest : public ::testing::Test, } } + // Supported, but we don't validate the values + void TestMetadataCurrentCatalog() { GTEST_SKIP(); } + void TestMetadataCurrentDbSchema() { GTEST_SKIP(); } + protected: SnowflakeQuirks quirks_; }; diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c index 87e20c998d..7b4072c2ba 100644 --- a/c/driver/sqlite/sqlite.c +++ b/c/driver/sqlite/sqlite.c @@ -86,6 +86,26 @@ AdbcStatusCode SqliteDatabaseSetOption(struct AdbcDatabase* database, const char return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode SqliteDatabaseSetOptionBytes(struct AdbcDatabase* database, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteDatabaseSetOptionDouble(struct AdbcDatabase* database, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + int OpenDatabase(const char* maybe_uri, sqlite3** db, struct AdbcError* error) { const char* uri = maybe_uri ? maybe_uri : kDefaultUri; int rc = sqlite3_open_v2(uri, db, @@ -120,6 +140,33 @@ AdbcStatusCode ExecuteQuery(struct SqliteConnection* conn, const char* query, return ADBC_STATUS_OK; } +AdbcStatusCode SqliteDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionBytes(struct AdbcDatabase* database, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionDouble(struct AdbcDatabase* database, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { CHECK_DB_INIT(database, error); @@ -204,6 +251,27 @@ AdbcStatusCode SqliteConnectionSetOption(struct AdbcConnection* connection, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode SqliteConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode SqliteConnectionInit(struct AdbcConnection* connection, struct AdbcDatabase* database, struct AdbcError* error) { @@ -282,7 +350,8 @@ AdbcStatusCode SqliteConnectionGetInfoImpl(const uint32_t* info_codes, } // NOLINT(whitespace/indent) AdbcStatusCode SqliteConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { CHECK_CONN_INIT(connection, error); @@ -754,6 +823,34 @@ AdbcStatusCode SqliteConnectionGetObjects(struct AdbcConnection* connection, int return BatchToArrayStream(&array, &schema, out, error); } +AdbcStatusCode SqliteConnectionGetOption(struct AdbcConnection* connection, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -1286,6 +1383,34 @@ AdbcStatusCode SqliteStatementBindStream(struct AdbcStatement* statement, return AdbcSqliteBinderSetArrayStream(&stmt->binder, stream, error); } +AdbcStatusCode SqliteStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -1375,6 +1500,27 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode SqliteStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode SqliteStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -1393,7 +1539,7 @@ AdbcStatusCode SqliteDriverInit(int version, void* raw_driver, struct AdbcError* } struct AdbcDriver* driver = (struct AdbcDriver*)raw_driver; - memset(driver, 0, sizeof(*driver)); + memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); driver->DatabaseInit = SqliteDatabaseInit; driver->DatabaseNew = SqliteDatabaseNew; driver->DatabaseRelease = SqliteDatabaseRelease; @@ -1425,24 +1571,91 @@ AdbcStatusCode SqliteDriverInit(int version, void* raw_driver, struct AdbcError* // Public names -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return SqliteDatabaseNew(database, error); +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteDatabaseGetOption(database, key, value, length, error); } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return SqliteDatabaseSetOption(database, key, value, error); +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return SqliteDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return SqliteDatabaseGetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return SqliteDatabaseGetOptionDouble(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return SqliteDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return SqliteDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return SqliteDatabaseRelease(database, error); } +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return SqliteDatabaseSetOption(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return SqliteDatabaseSetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return SqliteDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return SqliteDatabaseSetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SqliteConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return SqliteConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return SqliteConnectionGetOptionDouble(connection, key, value, error); +} + AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, struct AdbcError* error) { return SqliteConnectionNew(connection, error); @@ -1453,6 +1666,24 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const return SqliteConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SqliteConnectionSetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return SqliteConnectionSetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return SqliteConnectionSetOptionDouble(connection, key, value, error); +} + AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, struct AdbcDatabase* database, struct AdbcError* error) { @@ -1465,7 +1696,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { return SqliteConnectionGetInfo(connection, info_codes, info_codes_length, out, error); @@ -1481,6 +1712,20 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d table_type, column_name, out, error); } +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -1515,6 +1760,11 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return SqliteConnectionRollback(connection, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, struct AdbcStatement* statement, struct AdbcError* error) { @@ -1533,6 +1783,12 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, return SqliteStatementExecuteQuery(statement, out, rows_affected, error); } +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, struct AdbcError* error) { return SqliteStatementPrepare(statement, error); @@ -1561,6 +1817,29 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return SqliteStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SqliteStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return SqliteStatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return SqliteStatementGetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -1572,6 +1851,23 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha return SqliteStatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SqliteStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return SqliteStatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return SqliteStatementSetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index b03158cca6..d2821a306a 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -92,7 +92,27 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { return ddl; } + bool supports_bulk_ingest(const char* mode) const override { + return std::strcmp(mode, ADBC_INGEST_OPTION_MODE_APPEND) == 0 || + std::strcmp(mode, ADBC_INGEST_OPTION_MODE_CREATE) == 0; + } bool supports_concurrent_statements() const override { return true; } + bool supports_get_option() const override { return false; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC SQLite Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "SQLite"; + case ADBC_INFO_VENDOR_VERSION: + return "3."; + default: + return std::nullopt; + } + } std::string catalog() const override { return "main"; } std::string db_schema() const override { return ""; } diff --git a/c/driver_manager/CMakeLists.txt b/c/driver_manager/CMakeLists.txt index dd28470cf6..6fb51d9a6a 100644 --- a/c/driver_manager/CMakeLists.txt +++ b/c/driver_manager/CMakeLists.txt @@ -55,13 +55,28 @@ if(ADBC_BUILD_TESTS) driver-manager SOURCES adbc_driver_manager_test.cc - ../validation/adbc_validation.cc - ../validation/adbc_validation_util.cc EXTRA_LINK_LIBS adbc_driver_common + adbc_validation nanoarrow ${TEST_LINK_LIBS}) target_compile_features(adbc-driver-manager-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-manager-test SYSTEM PRIVATE ${REPOSITORY_ROOT}/c/vendor/nanoarrow/) + + add_test_case(version_100_compatibility_test + PREFIX + adbc + EXTRA_LABELS + driver-manager + SOURCES + adbc_version_100.c + adbc_version_100_compatibility_test.cc + EXTRA_LINK_LIBS + adbc_validation_util + nanoarrow + ${TEST_LINK_LIBS}) + target_compile_features(adbc-version-100-compatibility-test PRIVATE cxx_std_17) + target_include_directories(adbc-version-100-compatibility-test SYSTEM + PRIVATE ${REPOSITORY_ROOT}/c/vendor/nanoarrow/) endif() diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index d2929e2129..516bf9bbf7 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -90,17 +92,141 @@ void SetError(struct AdbcError* error, const std::string& message) { // Driver state -/// Hold the driver DLL and the driver release callback in the driver struct. -struct ManagerDriverState { - // The original release callback - AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error); +/// A driver DLL. +struct ManagedLibrary { + ManagedLibrary() : handle(nullptr) {} + ManagedLibrary(ManagedLibrary&& other) : handle(other.handle) { + other.handle = nullptr; + } + ManagedLibrary(const ManagedLibrary&) = delete; + ManagedLibrary& operator=(const ManagedLibrary&) = delete; + ManagedLibrary& operator=(ManagedLibrary&& other) noexcept { + this->handle = other.handle; + other.handle = nullptr; + return *this; + } + + ~ManagedLibrary() { Release(); } + + void Release() { + // TODO(apache/arrow-adbc#204): causes tests to segfault + // Need to refcount the driver DLL; also, errors may retain a reference to + // release() from the DLL - how to handle this? + } + + AdbcStatusCode Load(const char* library, struct AdbcError* error) { + std::string error_message; +#if defined(_WIN32) + HMODULE handle = LoadLibraryExA(library, NULL, 0); + if (!handle) { + error_message += library; + error_message += ": LoadLibraryExA() failed: "; + GetWinError(&error_message); + + std::string full_driver_name = library; + full_driver_name += ".dll"; + handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); + if (!handle) { + error_message += '\n'; + error_message += full_driver_name; + error_message += ": LoadLibraryExA() failed: "; + GetWinError(&error_message); + } + } + if (!handle) { + SetError(error, error_message); + return ADBC_STATUS_INTERNAL; + } else { + this->handle = handle; + } +#else + static const std::string kPlatformLibraryPrefix = "lib"; +#if defined(__APPLE__) + static const std::string kPlatformLibrarySuffix = ".dylib"; +#else + static const std::string kPlatformLibrarySuffix = ".so"; +#endif // defined(__APPLE__) + + void* handle = dlopen(library, RTLD_NOW | RTLD_LOCAL); + if (!handle) { + error_message = "dlopen() failed: "; + error_message += dlerror(); + + // If applicable, append the shared library prefix/extension and + // try again (this way you don't have to hardcode driver names by + // platform in the application) + const std::string driver_str = library; + + std::string full_driver_name; + if (driver_str.size() < kPlatformLibraryPrefix.size() || + driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != + 0) { + full_driver_name += kPlatformLibraryPrefix; + } + full_driver_name += library; + if (driver_str.size() < kPlatformLibrarySuffix.size() || + driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), + kPlatformLibrarySuffix.size(), + kPlatformLibrarySuffix) != 0) { + full_driver_name += kPlatformLibrarySuffix; + } + handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); + if (!handle) { + error_message += "\ndlopen() failed: "; + error_message += dlerror(); + } + } + if (handle) { + this->handle = handle; + } else { + return ADBC_STATUS_INTERNAL; + } +#endif // defined(_WIN32) + return ADBC_STATUS_OK; + } + + AdbcStatusCode Lookup(const char* name, void** func, struct AdbcError* error) { +#if defined(_WIN32) + void* load_handle = reinterpret_cast(GetProcAddress(handle, name)); + if (!load_handle) { + std::string message = "GetProcAddress("; + message += name; + message += ") failed: "; + GetWinError(&message); + SetError(error, message); + return ADBC_STATUS_INTERNAL; + } +#else + void* load_handle = dlsym(handle, name); + if (!load_handle) { + std::string message = "dlsym("; + message += name; + message += ") failed: "; + message += dlerror(); + SetError(error, message); + return ADBC_STATUS_INTERNAL; + } +#endif // defined(_WIN32) + *func = load_handle; + return ADBC_STATUS_OK; + } #if defined(_WIN32) // The loaded DLL HMODULE handle; +#else + void* handle; #endif // defined(_WIN32) }; +/// Hold the driver DLL and the driver release callback in the driver struct. +struct ManagerDriverState { + // The original release callback + AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error); + + ManagedLibrary handle; +}; + /// Unload the driver DLL. static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* error) { AdbcStatusCode status = ADBC_STATUS_OK; @@ -112,35 +238,132 @@ static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* if (state->driver_release) { status = state->driver_release(driver, error); } - -#if defined(_WIN32) - // TODO(apache/arrow-adbc#204): causes tests to segfault - // if (!FreeLibrary(state->handle)) { - // std::string message = "FreeLibrary() failed: "; - // GetWinError(&message); - // SetError(error, message); - // } -#endif // defined(_WIN32) + state->handle.Release(); driver->private_manager = nullptr; delete state; return status; } +// ArrowArrayStream wrapper to support AdbcErrorFromArrayStream + +struct ErrorArrayStream { + struct ArrowArrayStream stream; + struct AdbcDriver* private_driver; +}; + +void ErrorArrayStreamRelease(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return; + + auto* private_data = reinterpret_cast(stream->private_data); + private_data->stream.release(&private_data->stream); + delete private_data; + std::memset(stream, 0, sizeof(*stream)); +} + +const char* ErrorArrayStreamGetLastError(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return nullptr; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_last_error(&private_data->stream); +} + +int ErrorArrayStreamGetNext(struct ArrowArrayStream* stream, struct ArrowArray* array) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_next(&private_data->stream, array); +} + +int ErrorArrayStreamGetSchema(struct ArrowArrayStream* stream, + struct ArrowSchema* schema) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_schema(&private_data->stream, schema); +} + // Default stubs +int ErrorGetDetailCount(const struct AdbcError* error) { return 0; } + +struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError* error, int index) { + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return nullptr; +} + +void ErrorArrayStreamInit(struct ArrowArrayStream* out, + struct AdbcDriver* private_driver) { + if (!out || !out->release || + // Don't bother wrapping if driver didn't claim support + private_driver->ErrorFromArrayStream == ErrorFromArrayStream) { + return; + } + struct ErrorArrayStream* private_data = new ErrorArrayStream; + private_data->stream = *out; + private_data->private_driver = private_driver; + out->get_last_error = ErrorArrayStreamGetLastError; + out->get_next = ErrorArrayStreamGetNext; + out->get_schema = ErrorArrayStreamGetSchema; + out->release = ErrorArrayStreamRelease; + out->private_data = private_data; +} + +AdbcStatusCode DatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode DatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode DatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionCommit(struct AdbcConnection*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, uint32_t* info_codes, - size_t info_codes_length, struct ArrowArrayStream* out, - struct AdbcError* error) { +AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, + const uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* out, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } @@ -150,6 +373,39 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection*, int, const char*, co return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetStatistics(struct AdbcConnection*, const char*, const char*, + const char*, char, struct ArrowArrayStream*, + struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection*, const char*, const char*, const char*, struct ArrowSchema*, struct AdbcError* error) { @@ -178,11 +434,31 @@ AdbcStatusCode ConnectionSetOption(struct AdbcConnection*, const char*, const ch return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*, struct ArrowSchema*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementCancel(struct AdbcStatement* statement, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -191,6 +467,33 @@ AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionBytes(struct AdbcStatement* statement, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionDouble(struct AdbcStatement* statement, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -206,6 +509,21 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement*, const char*, const char return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementSetOptionBytes(struct AdbcStatement*, const char*, const uint8_t*, + size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionDouble(struct AdbcStatement* statement, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement*, const char*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -219,20 +537,129 @@ AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*, /// Temporary state while the database is being configured. struct TempDatabase { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; std::string driver; - // Default name (see adbc.h) - std::string entrypoint = "AdbcDriverInit"; + std::string entrypoint; AdbcDriverInitFunc init_func = nullptr; }; /// Temporary state while the database is being configured. struct TempConnection { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; }; + +static const char kDefaultEntrypoint[] = "AdbcDriverInit"; } // namespace +// Other helpers (intentionally not in an anonymous namespace so they can be tested) + +ADBC_EXPORT +std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) { + /// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit + /// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit + /// - proprietary_driver.dll -> AdbcProprietaryDriverInit + + // Potential path -> filename + // Treat both \ and / as directory separators on all platforms for simplicity + std::string filename; + { + size_t pos = driver.find_last_of("/\\"); + if (pos != std::string::npos) { + filename = driver.substr(pos + 1); + } else { + filename = driver; + } + } + + // Remove all extensions + { + size_t pos = filename.find('.'); + if (pos != std::string::npos) { + filename = filename.substr(0, pos); + } + } + + // Remove lib prefix + // https://stackoverflow.com/q/1878001/262727 + if (filename.rfind("lib", 0) == 0) { + filename = filename.substr(3); + } + + // Split on underscores, hyphens + // Capitalize and join + std::string entrypoint; + entrypoint.reserve(filename.size()); + size_t pos = 0; + while (pos < filename.size()) { + size_t prev = pos; + pos = filename.find_first_of("-_", pos); + // if pos == npos this is the entire filename + std::string token = filename.substr(prev, pos - prev); + // capitalize first letter + token[0] = std::toupper(static_cast(token[0])); + + entrypoint += token; + + if (pos != std::string::npos) { + pos++; + } + } + + if (entrypoint.rfind("Adbc", 0) != 0) { + entrypoint = "Adbc" + entrypoint; + } + entrypoint += "Init"; + + return entrypoint; +} + // Direct implementations of API methods +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetailCount(error); + } + return 0; +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetail(error, index); + } + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + if (!stream->private_data || stream->release != ErrorArrayStreamRelease) { + return nullptr; + } + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->private_driver->ErrorFromArrayStream(&private_data->stream, + status); +} + +#define INIT_ERROR(ERROR, SOURCE) \ + if ((ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ + (ERROR)->private_driver = (SOURCE)->private_driver; \ + } + +#define WRAP_STREAM(EXPR, OUT, SOURCE) \ + if (!(OUT)) { \ + /* Happens for ExecuteQuery where out is optional */ \ + return EXPR; \ + } \ + AdbcStatusCode status_code = EXPR; \ + ErrorArrayStreamInit(OUT, (SOURCE)->private_driver); \ + return status_code; + AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { // Allocate a temporary structure to store options pre-Init database->private_data = new TempDatabase(); @@ -240,9 +667,93 @@ AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOption(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const std::string* result = nullptr; + if (std::strcmp(key, "driver") == 0) { + result = &args->driver; + } else if (std::strcmp(key, "entrypoint") == 0) { + result = &args->entrypoint; + } else { + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + result = &it->second; + } + + if (*length <= result->size() + 1) { + // Enough space + std::memcpy(value, result->c_str(), result->size() + 1); + } + *length = result->size() + 1; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionBytes(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + const std::string& result = it->second; + + if (*length <= result.size()) { + // Enough space + std::memcpy(value, result.c_str(), result.size()); + } + *length = result.size(); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionInt(database, key, value, error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionDouble(database, key, value, error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (database->private_driver) { + INIT_ERROR(error, database); return database->private_driver->DatabaseSetOption(database, key, value, error); } @@ -257,6 +768,44 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionBytes(database, key, value, length, + error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionInt(database, key, value, error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionDouble(database, key, value, error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database, AdbcDriverInitFunc init_func, struct AdbcError* error) { @@ -288,11 +837,14 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* // So we don't confuse a driver into thinking it's initialized already database->private_data = nullptr; if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, + status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_1_0, database->private_driver, error); - } else { + } else if (!args->entrypoint.empty()) { status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), - ADBC_VERSION_1_0_0, database->private_driver, error); + ADBC_VERSION_1_1_0, database->private_driver, error); + } else { + status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, + database->private_driver, error); } if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease @@ -313,25 +865,49 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_driver = nullptr; return status; } - for (const auto& option : args->options) { + auto options = std::move(args->options); + auto bytes_options = std::move(args->bytes_options); + auto int_options = std::move(args->int_options); + auto double_options = std::move(args->double_options); + delete args; + + INIT_ERROR(error, database); + for (const auto& option : options) { status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); - if (status != ADBC_STATUS_OK) { - delete args; - // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - // Should be redundant, but ensure that AdbcDatabaseRelease - // below doesn't think that it contains a TempDatabase - database->private_data = nullptr; - return status; + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : bytes_options) { + status = database->private_driver->DatabaseSetOptionBytes( + database, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : int_options) { + status = database->private_driver->DatabaseSetOptionInt( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : double_options) { + status = database->private_driver->DatabaseSetOptionDouble( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + + if (status != ADBC_STATUS_OK) { + // Release the database + std::ignore = database->private_driver->DatabaseRelease(database, error); + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); } + delete database->private_driver; + database->private_driver = nullptr; + // Should be redundant, but ensure that AdbcDatabaseRelease + // below doesn't think that it contains a TempDatabase + database->private_data = nullptr; + return status; } - delete args; return database->private_driver->DatabaseInit(database, error); } @@ -346,6 +922,7 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, database); auto status = database->private_driver->DatabaseRelease(database, error); if (database->private_driver->release) { database->private_driver->release(database->private_driver, error); @@ -356,23 +933,35 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, return status; } +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionCancel(connection, error); +} + AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetInfo(connection, info_codes, - info_codes_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetInfo( + connection, info_codes, info_codes_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -384,9 +973,132 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetObjects( - connection, depth, catalog, db_schema, table_name, table_types, column_name, stream, - error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetObjects( + connection, depth, catalog, db_schema, table_name, table_types, + column_name, stream, error), + stream, connection); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.c_str(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOption(connection, key, value, length, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.data(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionBytes(connection, key, value, + length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionDouble(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatistics( + connection, catalog, db_schema, table_name, approximate == 1, out, error), + out, connection); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatisticNames(connection, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -397,6 +1109,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionGetTableSchema( connection, catalog, db_schema, table_name, schema, error); } @@ -407,7 +1120,10 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetTableTypes(connection, stream, error), + stream, connection); } AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, @@ -423,6 +1139,11 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, TempConnection* args = reinterpret_cast(connection->private_data); connection->private_data = nullptr; std::unordered_map options = std::move(args->options); + std::unordered_map bytes_options = + std::move(args->bytes_options); + std::unordered_map int_options = std::move(args->int_options); + std::unordered_map double_options = + std::move(args->double_options); delete args; auto status = database->private_driver->ConnectionNew(connection, error); @@ -434,6 +1155,24 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, connection, option.first.c_str(), option.second.c_str(), error); if (status != ADBC_STATUS_OK) return status; } + for (const auto& option : bytes_options) { + status = database->private_driver->ConnectionSetOptionBytes( + connection, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : int_options) { + status = database->private_driver->ConnectionSetOptionInt( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : double_options) { + status = database->private_driver->ConnectionSetOptionDouble( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionInit(connection, database, error); } @@ -455,8 +1194,10 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionReadPartition( - connection, serialized_partition, serialized_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionReadPartition( + connection, serialized_partition, serialized_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, @@ -470,6 +1211,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->ConnectionRelease(connection, error); connection->private_driver = nullptr; return status; @@ -480,6 +1222,7 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionRollback(connection, error); } @@ -495,15 +1238,71 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const args->options[key] = value; return ADBC_STATUS_OK; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionBytes(connection, key, value, + length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionDouble: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionDouble(connection, key, value, + error); +} + AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBind(statement, values, schema, error); } @@ -513,9 +1312,19 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementCancel(statement, error); +} + // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, ArrowSchema* schema, @@ -525,6 +1334,7 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementExecutePartitions( statement, schema, partitions, rows_affected, error); } @@ -536,8 +1346,62 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, - error); + INIT_ERROR(error, statement); + WRAP_STREAM(statement->private_driver->StatementExecuteQuery(statement, out, + rows_affected, error), + out, statement); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOption(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionDouble(statement, key, value, + error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -546,6 +1410,7 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementGetParameterSchema(statement, schema, error); } @@ -555,6 +1420,7 @@ AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->StatementNew(connection, statement, error); statement->private_driver = connection->private_driver; return status; @@ -565,6 +1431,7 @@ AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementPrepare(statement, error); } @@ -573,6 +1440,7 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); auto status = statement->private_driver->StatementRelease(statement, error); statement->private_driver = nullptr; return status; @@ -583,14 +1451,47 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionDouble(statement, key, value, + error); +} + AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSqlQuery(statement, query, error); } @@ -600,6 +1501,7 @@ AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); } @@ -636,137 +1538,80 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, AdbcDriverInitFunc init_func; std::string error_message; - if (version != ADBC_VERSION_1_0_0) { - SetError(error, "Only ADBC 1.0.0 is supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - auto* driver = reinterpret_cast(raw_driver); - - if (!entrypoint) { - // Default entrypoint (see adbc.h) - entrypoint = "AdbcDriverInit"; + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; } -#if defined(_WIN32) - - HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); - if (!handle) { - error_message += driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - - std::string full_driver_name = driver_name; - full_driver_name += ".lib"; - handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); - if (!handle) { - error_message += '\n'; - error_message += full_driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - } - } - if (!handle) { - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; } + auto* driver = reinterpret_cast(raw_driver); - void* load_handle = reinterpret_cast(GetProcAddress(handle, entrypoint)); - init_func = reinterpret_cast(load_handle); - if (!init_func) { - std::string message = "GetProcAddress("; - message += entrypoint; - message += ") failed: "; - GetWinError(&message); - if (!FreeLibrary(handle)) { - message += "\nFreeLibrary() failed: "; - GetWinError(&message); - } - SetError(error, message); - return ADBC_STATUS_INTERNAL; + ManagedLibrary library; + AdbcStatusCode status = library.Load(driver_name, error); + if (status != ADBC_STATUS_OK) { + // AdbcDatabaseInit tries to call this if set + driver->release = nullptr; + return status; } -#else - -#if defined(__APPLE__) - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".dylib"; -#else - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".so"; -#endif // defined(__APPLE__) - - void* handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message = "dlopen() failed: "; - error_message += dlerror(); - - // If applicable, append the shared library prefix/extension and - // try again (this way you don't have to hardcode driver names by - // platform in the application) - const std::string driver_str = driver_name; - - std::string full_driver_name; - if (driver_str.size() < kPlatformLibraryPrefix.size() || - driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != - 0) { - full_driver_name += kPlatformLibraryPrefix; - } - full_driver_name += driver_name; - if (driver_str.size() < kPlatformLibrarySuffix.size() || - driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix) != 0) { - full_driver_name += kPlatformLibrarySuffix; - } - handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message += "\ndlopen() failed: "; - error_message += dlerror(); + void* load_handle = nullptr; + if (entrypoint) { + status = library.Lookup(entrypoint, &load_handle, error); + } else { + auto name = AdbcDriverManagerDefaultEntrypoint(driver_name); + status = library.Lookup(name.c_str(), &load_handle, error); + if (status != ADBC_STATUS_OK) { + status = library.Lookup(kDefaultEntrypoint, &load_handle, error); } } - if (!handle) { - SetError(error, error_message); - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; - return ADBC_STATUS_INTERNAL; - } - void* load_handle = dlsym(handle, entrypoint); - if (!load_handle) { - std::string message = "dlsym("; - message += entrypoint; - message += ") failed: "; - message += dlerror(); - SetError(error, message); - return ADBC_STATUS_INTERNAL; + if (status != ADBC_STATUS_OK) { + library.Release(); + return status; } init_func = reinterpret_cast(load_handle); -#endif // defined(_WIN32) - - AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); + status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); if (status == ADBC_STATUS_OK) { ManagerDriverState* state = new ManagerDriverState; state->driver_release = driver->release; -#if defined(_WIN32) - state->handle = handle; -#endif // defined(_WIN32) + state->handle = std::move(library); driver->release = &ReleaseDriver; driver->private_manager = state; } else { -#if defined(_WIN32) - if (!FreeLibrary(handle)) { - std::string message = "FreeLibrary() failed: "; - GetWinError(&message); - SetError(error, message); - } -#endif // defined(_WIN32) + library.Release(); } return status; } AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void* raw_driver, struct AdbcError* error) { + constexpr std::array kSupportedVersions = { + ADBC_VERSION_1_1_0, + ADBC_VERSION_1_0_0, + }; + + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + #define FILL_DEFAULT(DRIVER, STUB) \ if (!DRIVER->STUB) { \ DRIVER->STUB = &STUB; \ @@ -777,12 +1622,20 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers return ADBC_STATUS_INTERNAL; \ } - auto result = init_func(version, raw_driver, error); + // Starting from the passed version, try each (older) version in + // succession with the underlying driver until we find one that's + // accepted. + AdbcStatusCode result = ADBC_STATUS_NOT_IMPLEMENTED; + for (const int try_version : kSupportedVersions) { + if (try_version > version) continue; + result = init_func(try_version, raw_driver, error); + if (result != ADBC_STATUS_NOT_IMPLEMENTED) break; + } if (result != ADBC_STATUS_OK) { return result; } - if (version == ADBC_VERSION_1_0_0) { + if (version >= ADBC_VERSION_1_0_0) { auto* driver = reinterpret_cast(raw_driver); CHECK_REQUIRED(driver, DatabaseNew); CHECK_REQUIRED(driver, DatabaseInit); @@ -812,6 +1665,41 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers FILL_DEFAULT(driver, StatementSetSqlQuery); FILL_DEFAULT(driver, StatementSetSubstraitPlan); } + if (version >= ADBC_VERSION_1_1_0) { + auto* driver = reinterpret_cast(raw_driver); + FILL_DEFAULT(driver, ErrorGetDetailCount); + FILL_DEFAULT(driver, ErrorGetDetail); + FILL_DEFAULT(driver, ErrorFromArrayStream); + + FILL_DEFAULT(driver, DatabaseGetOption); + FILL_DEFAULT(driver, DatabaseGetOptionBytes); + FILL_DEFAULT(driver, DatabaseGetOptionDouble); + FILL_DEFAULT(driver, DatabaseGetOptionInt); + FILL_DEFAULT(driver, DatabaseSetOptionBytes); + FILL_DEFAULT(driver, DatabaseSetOptionDouble); + FILL_DEFAULT(driver, DatabaseSetOptionInt); + + FILL_DEFAULT(driver, ConnectionCancel); + FILL_DEFAULT(driver, ConnectionGetOption); + FILL_DEFAULT(driver, ConnectionGetOptionBytes); + FILL_DEFAULT(driver, ConnectionGetOptionDouble); + FILL_DEFAULT(driver, ConnectionGetOptionInt); + FILL_DEFAULT(driver, ConnectionGetStatistics); + FILL_DEFAULT(driver, ConnectionGetStatisticNames); + FILL_DEFAULT(driver, ConnectionSetOptionBytes); + FILL_DEFAULT(driver, ConnectionSetOptionDouble); + FILL_DEFAULT(driver, ConnectionSetOptionInt); + + FILL_DEFAULT(driver, StatementCancel); + FILL_DEFAULT(driver, StatementExecuteSchema); + FILL_DEFAULT(driver, StatementGetOption); + FILL_DEFAULT(driver, StatementGetOptionBytes); + FILL_DEFAULT(driver, StatementGetOptionDouble); + FILL_DEFAULT(driver, StatementGetOptionInt); + FILL_DEFAULT(driver, StatementSetOptionBytes); + FILL_DEFAULT(driver, StatementSetOptionDouble); + FILL_DEFAULT(driver, StatementSetOptionInt); + } return ADBC_STATUS_OK; diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc index d3ff6f58e1..5e70390388 100644 --- a/c/driver_manager/adbc_driver_manager_test.cc +++ b/c/driver_manager/adbc_driver_manager_test.cc @@ -27,6 +27,8 @@ #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" +std::string AdbcDriverManagerDefaultEntrypoint(const std::string& filename); + // Tests of the SQLite example driver, except using the driver manager namespace adbc { @@ -40,7 +42,7 @@ class DriverManager : public ::testing::Test { std::memset(&driver, 0, sizeof(driver)); std::memset(&error, 0, sizeof(error)); - ASSERT_THAT(AdbcLoadDriver("adbc_driver_sqlite", nullptr, ADBC_VERSION_1_0_0, &driver, + ASSERT_THAT(AdbcLoadDriver("adbc_driver_sqlite", nullptr, ADBC_VERSION_1_1_0, &driver, &error), IsOkStatus(&error)); } @@ -191,7 +193,27 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { } } + bool supports_bulk_ingest(const char* mode) const override { + return std::strcmp(mode, ADBC_INGEST_OPTION_MODE_APPEND) == 0 || + std::strcmp(mode, ADBC_INGEST_OPTION_MODE_CREATE) == 0; + } bool supports_concurrent_statements() const override { return true; } + bool supports_get_option() const override { return false; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC SQLite Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "SQLite"; + case ADBC_INFO_VENDOR_VERSION: + return "3."; + default: + return std::nullopt; + } + } }; class SqliteDatabaseTest : public ::testing::Test, public adbc_validation::DatabaseTest { @@ -242,4 +264,41 @@ class SqliteStatementTest : public ::testing::Test, }; ADBCV_TEST_STATEMENT(SqliteStatementTest) +TEST(AdbcDriverManagerInternal, AdbcDriverManagerDefaultEntrypoint) { + for (const auto& driver : { + "adbc_driver_sqlite", + "adbc_driver_sqlite.dll", + "driver_sqlite", + "libadbc_driver_sqlite", + "libadbc_driver_sqlite.so", + "libadbc_driver_sqlite.so.6.0.0", + "/usr/lib/libadbc_driver_sqlite.so", + "/usr/lib/libadbc_driver_sqlite.so.6.0.0", + "C:\\System32\\adbc_driver_sqlite.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcDriverSqliteInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } + + for (const auto& driver : { + "adbc_sqlite", + "sqlite", + "/usr/lib/sqlite.so", + "C:\\System32\\sqlite.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcSqliteInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } + + for (const auto& driver : { + "proprietary_engine", + "libproprietary_engine.so.6.0.0", + "/usr/lib/proprietary_engine.so", + "C:\\System32\\proprietary_engine.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcProprietaryEngineInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } +} + } // namespace adbc diff --git a/c/driver_manager/adbc_version_100.c b/c/driver_manager/adbc_version_100.c new file mode 100644 index 0000000000..48114cdb43 --- /dev/null +++ b/c/driver_manager/adbc_version_100.c @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "adbc_version_100.h" + +#include + +struct Version100Database { + int dummy; +}; + +static struct Version100Database kDatabase; + +struct Version100Connection { + int dummy; +}; + +static struct Version100Connection kConnection; + +struct Version100Statement { + int dummy; +}; + +static struct Version100Statement kStatement; + +AdbcStatusCode Version100DatabaseInit(struct AdbcDatabase* database, + struct AdbcError* error) { + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DatabaseNew(struct AdbcDatabase* database, + struct AdbcError* error) { + database->private_data = &kDatabase; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DatabaseRelease(struct AdbcDatabase* database, + struct AdbcError* error) { + database->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + connection->private_data = &kConnection; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100StatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* stream, + int64_t* rows_affected, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode Version100StatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + statement->private_data = &kStatement; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100StatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + statement->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + connection->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DriverInit(int version, void* raw_driver, + struct AdbcError* error) { + if (version != ADBC_VERSION_1_0_0) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + struct AdbcDriverVersion100* driver = (struct AdbcDriverVersion100*)raw_driver; + memset(driver, 0, sizeof(struct AdbcDriverVersion100)); + + driver->DatabaseInit = &Version100DatabaseInit; + driver->DatabaseNew = &Version100DatabaseNew; + driver->DatabaseRelease = &Version100DatabaseRelease; + + driver->ConnectionInit = &Version100ConnectionInit; + driver->ConnectionNew = &Version100ConnectionNew; + driver->ConnectionRelease = &Version100ConnectionRelease; + + driver->StatementExecuteQuery = &Version100StatementExecuteQuery; + driver->StatementNew = &Version100StatementNew; + driver->StatementRelease = &Version100StatementRelease; + + return ADBC_STATUS_OK; +} diff --git a/c/driver_manager/adbc_version_100.h b/c/driver_manager/adbc_version_100.h new file mode 100644 index 0000000000..b349f86f73 --- /dev/null +++ b/c/driver_manager/adbc_version_100.h @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A dummy version 1.0.0 ADBC driver to test compatibility. + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct AdbcErrorVersion100 { + char* message; + int32_t vendor_code; + char sqlstate[5]; + void (*release)(struct AdbcError* error); +}; + +struct AdbcDriverVersion100 { + void* private_data; + void* private_manager; + AdbcStatusCode (*release)(struct AdbcDriver* driver, struct AdbcError* error); + + AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*DatabaseNew)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOption)(struct AdbcDatabase*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); + + AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, + const char*, const char*, const char**, + const char*, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetTableSchema)(struct AdbcConnection*, const char*, + const char*, const char*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetTableTypes)(struct AdbcConnection*, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcDatabase*, + struct AdbcError*); + AdbcStatusCode (*ConnectionNew)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*ConnectionReadPartition)(struct AdbcConnection*, const uint8_t*, + size_t, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionRollback)(struct AdbcConnection*, struct AdbcError*); + + AdbcStatusCode (*StatementBind)(struct AdbcStatement*, struct ArrowArray*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*StatementBindStream)(struct AdbcStatement*, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*StatementExecuteQuery)(struct AdbcStatement*, struct ArrowArrayStream*, + int64_t*, struct AdbcError*); + AdbcStatusCode (*StatementExecutePartitions)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcPartitions*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetParameterSchema)(struct AdbcStatement*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*StatementNew)(struct AdbcConnection*, struct AdbcStatement*, + struct AdbcError*); + AdbcStatusCode (*StatementPrepare)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementRelease)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementSetOption)(struct AdbcStatement*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement*, const char*, + struct AdbcError*); + AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, + size_t, struct AdbcError*); +}; + +AdbcStatusCode Version100DriverInit(int version, void* driver, struct AdbcError* error); + +#ifdef __cplusplus +} +#endif diff --git a/c/driver_manager/adbc_version_100_compatibility_test.cc b/c/driver_manager/adbc_version_100_compatibility_test.cc new file mode 100644 index 0000000000..27e5f5d997 --- /dev/null +++ b/c/driver_manager/adbc_version_100_compatibility_test.cc @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include + +#include "adbc.h" +#include "adbc_driver_manager.h" +#include "adbc_version_100.h" +#include "validation/adbc_validation_util.h" + +namespace adbc { + +using adbc_validation::IsOkStatus; +using adbc_validation::IsStatus; + +class AdbcVersion : public ::testing::Test { + public: + void SetUp() override { + std::memset(&driver, 0, sizeof(driver)); + std::memset(&error, 0, sizeof(error)); + } + + void TearDown() override { + if (error.release) { + error.release(&error); + } + + if (driver.release) { + ASSERT_THAT(driver.release(&driver, &error), IsOkStatus(&error)); + ASSERT_EQ(driver.private_data, nullptr); + ASSERT_EQ(driver.private_manager, nullptr); + } + } + + protected: + struct AdbcDriver driver = {}; + struct AdbcError error = {}; +}; + +TEST_F(AdbcVersion, StructSize) { + ASSERT_EQ(sizeof(AdbcErrorVersion100), ADBC_ERROR_1_0_0_SIZE); + ASSERT_EQ(sizeof(AdbcError), ADBC_ERROR_1_1_0_SIZE); + + ASSERT_EQ(sizeof(AdbcDriverVersion100), ADBC_DRIVER_1_0_0_SIZE); + ASSERT_EQ(sizeof(AdbcDriver), ADBC_DRIVER_1_1_0_SIZE); +} + +// Initialize a version 1.0.0 driver with the version 1.1.0 driver struct. +TEST_F(AdbcVersion, OldDriverNewLayout) { + ASSERT_THAT(Version100DriverInit(ADBC_VERSION_1_1_0, &driver, &error), + IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + ASSERT_THAT(Version100DriverInit(ADBC_VERSION_1_0_0, &driver, &error), + IsOkStatus(&error)); +} + +// Initialize a version 1.0.0 driver with the new driver manager/new version. +TEST_F(AdbcVersion, OldDriverNewManager) { + ASSERT_THAT(AdbcLoadDriverFromInitFunc(&Version100DriverInit, ADBC_VERSION_1_1_0, + &driver, &error), + IsOkStatus(&error)); + + EXPECT_NE(driver.ErrorGetDetailCount, nullptr); + EXPECT_NE(driver.ErrorGetDetail, nullptr); + + EXPECT_NE(driver.DatabaseGetOption, nullptr); + EXPECT_NE(driver.DatabaseGetOptionBytes, nullptr); + EXPECT_NE(driver.DatabaseGetOptionDouble, nullptr); + EXPECT_NE(driver.DatabaseGetOptionInt, nullptr); + EXPECT_NE(driver.DatabaseSetOptionInt, nullptr); + EXPECT_NE(driver.DatabaseSetOptionDouble, nullptr); + + EXPECT_NE(driver.ConnectionCancel, nullptr); + EXPECT_NE(driver.ConnectionGetOption, nullptr); + EXPECT_NE(driver.ConnectionGetOptionBytes, nullptr); + EXPECT_NE(driver.ConnectionGetOptionDouble, nullptr); + EXPECT_NE(driver.ConnectionGetOptionInt, nullptr); + EXPECT_NE(driver.ConnectionSetOptionInt, nullptr); + EXPECT_NE(driver.ConnectionSetOptionDouble, nullptr); + + EXPECT_NE(driver.StatementCancel, nullptr); + EXPECT_NE(driver.StatementExecuteSchema, nullptr); + EXPECT_NE(driver.StatementGetOption, nullptr); + EXPECT_NE(driver.StatementGetOptionBytes, nullptr); + EXPECT_NE(driver.StatementGetOptionDouble, nullptr); + EXPECT_NE(driver.StatementGetOptionInt, nullptr); + EXPECT_NE(driver.StatementSetOptionInt, nullptr); + EXPECT_NE(driver.StatementSetOptionDouble, nullptr); +} + +// N.B. see postgresql_test.cc for backwards compatibility test of AdbcError +// N.B. see postgresql_test.cc for backwards compatibility test of AdbcDriver + +} // namespace adbc diff --git a/c/integration/duckdb/CMakeLists.txt b/c/integration/duckdb/CMakeLists.txt index 52fb9d0f8c..8053713f3d 100644 --- a/c/integration/duckdb/CMakeLists.txt +++ b/c/integration/duckdb/CMakeLists.txt @@ -49,6 +49,7 @@ if(ADBC_BUILD_TESTS) CACHE INTERNAL "Disable UBSAN") # Force cmake to honor our options here in the subproject cmake_policy(SET CMP0077 NEW) + message(STATUS "Fetching DuckDB") fetchcontent_makeavailable(duckdb) include_directories(SYSTEM ${REPOSITORY_ROOT}) diff --git a/c/integration/duckdb/duckdb_test.cc b/c/integration/duckdb/duckdb_test.cc index fd6e1984e6..c42c3813cf 100644 --- a/c/integration/duckdb/duckdb_test.cc +++ b/c/integration/duckdb/duckdb_test.cc @@ -46,7 +46,7 @@ class DuckDbQuirks : public adbc_validation::DriverQuirks { std::string BindParameter(int index) const override { return "?"; } - bool supports_bulk_ingest() const override { return false; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_dynamic_parameter_binding() const override { return false; } bool supports_get_sql_info() const override { return false; } @@ -96,6 +96,12 @@ class DuckDbStatementTest : public ::testing::Test, void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; } + void TestSqlQueryErrors() { GTEST_SKIP() << "DuckDB does not set AdbcError.release"; } + + void TestErrorCompatibility() { + GTEST_SKIP() << "DuckDB does not set AdbcError.release"; + } + protected: DuckDbQuirks quirks_; }; diff --git a/c/symbols.map b/c/symbols.map index 5e965b355e..c9464b2da4 100644 --- a/c/symbols.map +++ b/c/symbols.map @@ -20,6 +20,16 @@ # Only expose symbols from the ADBC API Adbc*; + # Expose driver-specific initialization routines + FlightSQLDriverInit; + PostgresqlDriverInit; + SnowflakeDriverInit; + SqliteDriverInit; + + extern "C++" { + Adbc*; + }; + local: *; }; diff --git a/c/validation/CMakeLists.txt b/c/validation/CMakeLists.txt index 2f6549b5e7..bab7a63b19 100644 --- a/c/validation/CMakeLists.txt +++ b/c/validation/CMakeLists.txt @@ -15,11 +15,24 @@ # specific language governing permissions and limitations # under the License. -add_library(adbc_validation OBJECT adbc_validation.cc adbc_validation_util.cc) +add_library(adbc_validation_util STATIC adbc_validation_util.cc) +adbc_configure_target(adbc_validation_util) +target_compile_features(adbc_validation_util PRIVATE cxx_std_17) +target_include_directories(adbc_validation_util SYSTEM + PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/driver/" + "${REPOSITORY_ROOT}/c/vendor/") +target_link_libraries(adbc_validation_util PUBLIC adbc_driver_common nanoarrow + GTest::gtest GTest::gmock) + +add_library(adbc_validation OBJECT adbc_validation.cc) adbc_configure_target(adbc_validation) target_compile_features(adbc_validation PRIVATE cxx_std_17) target_include_directories(adbc_validation SYSTEM PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/driver/" "${REPOSITORY_ROOT}/c/vendor/") -target_link_libraries(adbc_validation PUBLIC adbc_driver_common nanoarrow GTest::gtest - GTest::gmock) +target_link_libraries(adbc_validation + PUBLIC adbc_driver_common + adbc_validation_util + nanoarrow + GTest::gtest + GTest::gmock) diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc index 8f519b0417..afb0260a63 100644 --- a/c/validation/adbc_validation.cc +++ b/c/validation/adbc_validation.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,7 @@ #include #include #include +#include #include "adbc_validation_util.h" @@ -101,7 +103,7 @@ AdbcStatusCode DriverQuirks::EnsureSampleTable(struct AdbcConnection* connection AdbcStatusCode DriverQuirks::CreateSampleTable(struct AdbcConnection* connection, const std::string& name, struct AdbcError* error) const { - if (!supports_bulk_ingest()) { + if (!supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { return ADBC_STATUS_NOT_IMPLEMENTED; } return DoIngestSampleTable(connection, name, error); @@ -247,6 +249,56 @@ void ConnectionTest::TestAutocommitToggle() { //------------------------------------------------------------ // Tests of metadata +std::optional ConnectionGetOption(struct AdbcConnection* connection, + std::string_view option, + struct AdbcError* error) { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + AdbcStatusCode status = + AdbcConnectionGetOption(connection, option.data(), buffer, &buffer_size, error); + EXPECT_THAT(status, IsOkStatus(error)); + if (status != ADBC_STATUS_OK) return std::nullopt; + EXPECT_GT(buffer_size, 0); + if (buffer_size == 0) return std::nullopt; + return std::string(buffer, buffer_size - 1); +} + +void ConnectionTest::TestMetadataCurrentCatalog() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + if (quirks()->supports_metadata_current_catalog()) { + ASSERT_THAT( + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, &error), + ::testing::Optional(quirks()->catalog())); + } else { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + ASSERT_THAT( + AdbcConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, + buffer, &buffer_size, &error), + IsStatus(ADBC_STATUS_NOT_FOUND)); + } +} + +void ConnectionTest::TestMetadataCurrentDbSchema() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + if (quirks()->supports_metadata_current_db_schema()) { + ASSERT_THAT(ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + &error), + ::testing::Optional(quirks()->db_schema())); + } else { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + ASSERT_THAT( + AdbcConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + buffer, &buffer_size, &error), + IsStatus(ADBC_STATUS_NOT_FOUND)); + } +} + void ConnectionTest::TestMetadataGetInfo() { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); @@ -255,83 +307,110 @@ void ConnectionTest::TestMetadataGetInfo() { GTEST_SKIP(); } - StreamReader reader; - std::vector info = { - ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, - ADBC_INFO_VENDOR_NAME, - ADBC_INFO_VENDOR_VERSION, - }; + for (uint32_t info_code : { + ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ADBC_VERSION, + ADBC_INFO_VENDOR_NAME, + ADBC_INFO_VENDOR_VERSION, + }) { + SCOPED_TRACE("info_code = " + std::to_string(info_code)); + std::optional expected = quirks()->supports_get_sql_info(info_code); - ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(), - &reader.stream.value, &error), - IsOkStatus(&error)); - ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - ASSERT_NO_FATAL_FAILURE(CompareSchema( - &reader.schema.value, { - {"info_name", NANOARROW_TYPE_UINT32, NOT_NULL}, - {"info_value", NANOARROW_TYPE_DENSE_UNION, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1], - { - {"string_value", NANOARROW_TYPE_STRING, NULLABLE}, - {"bool_value", NANOARROW_TYPE_BOOL, NULLABLE}, - {"int64_value", NANOARROW_TYPE_INT64, NULLABLE}, - {"int32_bitmask", NANOARROW_TYPE_INT32, NULLABLE}, - {"string_list", NANOARROW_TYPE_LIST, NULLABLE}, - {"int32_to_int32_list_map", NANOARROW_TYPE_MAP, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[4], - { - {"item", NANOARROW_TYPE_STRING, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[5], - { - {"entries", NANOARROW_TYPE_STRUCT, NOT_NULL}, - })); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1]->children[5]->children[0], - { - {"key", NANOARROW_TYPE_INT32, NOT_NULL}, - {"value", NANOARROW_TYPE_LIST, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1]->children[5]->children[0]->children[1], - { - {"item", NANOARROW_TYPE_INT32, NULLABLE}, - })); + if (!expected.has_value()) continue; - std::vector seen; - while (true) { - ASSERT_NO_FATAL_FAILURE(reader.Next()); - if (!reader.array->release) break; + uint32_t info[] = {info_code}; + + StreamReader reader; + ASSERT_THAT(AdbcConnectionGetInfo(&connection, info, 1, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - for (int64_t row = 0; row < reader.array->length; row++) { - ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); - const uint32_t code = - reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; - seen.push_back(code); - - switch (code) { - case ADBC_INFO_DRIVER_NAME: - case ADBC_INFO_DRIVER_VERSION: - case ADBC_INFO_VENDOR_NAME: - case ADBC_INFO_VENDOR_VERSION: - // UTF8 - ASSERT_EQ(uint8_t(0), - reader.array_view->children[1]->buffer_views[0].data.as_uint8[row]); - default: - // Ignored - break; + ASSERT_NO_FATAL_FAILURE(CompareSchema( + &reader.schema.value, { + {"info_name", NANOARROW_TYPE_UINT32, NOT_NULL}, + {"info_value", NANOARROW_TYPE_DENSE_UNION, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1], + { + {"string_value", NANOARROW_TYPE_STRING, NULLABLE}, + {"bool_value", NANOARROW_TYPE_BOOL, NULLABLE}, + {"int64_value", NANOARROW_TYPE_INT64, NULLABLE}, + {"int32_bitmask", NANOARROW_TYPE_INT32, NULLABLE}, + {"string_list", NANOARROW_TYPE_LIST, NULLABLE}, + {"int32_to_int32_list_map", NANOARROW_TYPE_MAP, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[4], + { + {"item", NANOARROW_TYPE_STRING, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5], + { + {"entries", NANOARROW_TYPE_STRUCT, NOT_NULL}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5]->children[0], + { + {"key", NANOARROW_TYPE_INT32, NOT_NULL}, + {"value", NANOARROW_TYPE_LIST, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5]->children[0]->children[1], + { + {"item", NANOARROW_TYPE_INT32, NULLABLE}, + })); + + std::vector seen; + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; + + for (int64_t row = 0; row < reader.array->length; row++) { + ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); + const uint32_t code = + reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; + seen.push_back(code); + if (code != info_code) { + continue; + } + + ASSERT_TRUE(expected.has_value()) << "Got unexpected info code " << code; + + uint8_t type_code = + reader.array_view->children[1]->buffer_views[0].data.as_uint8[row]; + int32_t offset = + reader.array_view->children[1]->buffer_views[1].data.as_int32[row]; + ASSERT_NO_FATAL_FAILURE(std::visit( + [&](auto&& expected_value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + ASSERT_EQ(uint8_t(2), type_code); + EXPECT_EQ(expected_value, + ArrowArrayViewGetIntUnsafe( + reader.array_view->children[1]->children[2], offset)); + } else if constexpr (std::is_same_v) { + ASSERT_EQ(uint8_t(0), type_code); + struct ArrowStringView view = ArrowArrayViewGetStringUnsafe( + reader.array_view->children[1]->children[0], offset); + EXPECT_THAT(std::string_view(static_cast(view.data), + view.size_bytes), + ::testing::HasSubstr(expected_value)); + } else { + static_assert(!sizeof(T), "not yet implemented"); + } + }, + *expected)) + << "code: " << type_code; } } + EXPECT_THAT(seen, ::testing::IsSupersetOf(info)); } - ASSERT_THAT(seen, ::testing::UnorderedElementsAreArray(info)); } void ConnectionTest::TestMetadataGetTableSchema() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); @@ -911,6 +990,58 @@ void ConnectionTest::TestMetadataGetObjectsPrimaryKey() { ASSERT_EQ(constraint_column_name, "id"); } +void ConnectionTest::TestMetadataGetObjectsCancel() { + if (!quirks()->supports_cancel() || !quirks()->supports_get_objects()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + StreamReader reader; + ASSERT_THAT( + AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_CATALOGS, nullptr, nullptr, + nullptr, nullptr, nullptr, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_THAT(AdbcConnectionCancel(&connection, &error), IsOkStatus(&error)); + + while (true) { + int err = reader.MaybeNext(); + if (err != 0) { + ASSERT_THAT(err, ::testing::AnyOf(0, IsErrno(ECANCELED, &reader.stream.value, + /*ArrowError*/ nullptr))); + } + if (!reader.array->release) break; + } +} + +void ConnectionTest::TestMetadataGetStatisticNames() { + if (!quirks()->supports_statistics()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + StreamReader reader; + ASSERT_THAT(AdbcConnectionGetStatisticNames(&connection, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_NO_FATAL_FAILURE(CompareSchema( + &reader.schema.value, { + {"statistic_name", NANOARROW_TYPE_STRING, NOT_NULL}, + {"statistic_key", NANOARROW_TYPE_INT16, NOT_NULL}, + })); + + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; + } +} + //------------------------------------------------------------ // Tests of AdbcStatement @@ -965,7 +1096,7 @@ void StatementTest::TestRelease() { template void StatementTest::TestSqlIngestType(ArrowType type, const std::vector>& values) { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1105,7 +1236,7 @@ void StatementTest::TestSqlIngestDate32() { template void StatementTest::TestSqlIngestTimestampType(const char* timezone) { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1203,7 +1334,7 @@ void StatementTest::TestSqlIngestTimestampTz() { } void StatementTest::TestSqlIngestInterval() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1311,7 +1442,8 @@ void StatementTest::TestSqlIngestTableEscaping() { } void StatementTest::TestSqlIngestAppend() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_APPEND)) { GTEST_SKIP(); } @@ -1389,8 +1521,185 @@ void StatementTest::TestSqlIngestAppend() { ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } +void StatementTest::TestSqlIngestReplace() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_REPLACE)) { + GTEST_SKIP(); + } + + // Ingest + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_REPLACE, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(1, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE(CompareArray(reader.array_view->children[0], {42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + + // Replace + // Re-initialize since Bind() should take ownership of data + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {-42, -42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_REPLACE, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(2, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {-42, -42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } +} + +void StatementTest::TestSqlIngestCreateAppend() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE_APPEND)) { + GTEST_SKIP(); + } + + ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), + IsOkStatus(&error)); + + // Ingest + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_CREATE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + // Append + // Re-initialize since Bind() should take ownership of data + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42, 42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(3, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {42, 42, 42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestSqlIngestErrors() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1470,7 +1779,7 @@ void StatementTest::TestSqlIngestErrors() { } void StatementTest::TestSqlIngestMultipleConnections() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1540,7 +1849,7 @@ void StatementTest::TestSqlIngestMultipleConnections() { } void StatementTest::TestSqlIngestSample() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1786,7 +2095,7 @@ void StatementTest::TestSqlPrepareSelectParams() { } void StatementTest::TestSqlPrepareUpdate() { - if (!quirks()->supports_bulk_ingest() || + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || !quirks()->supports_dynamic_parameter_binding()) { GTEST_SKIP(); } @@ -1865,7 +2174,7 @@ void StatementTest::TestSqlPrepareUpdateNoParams() { } void StatementTest::TestSqlPrepareUpdateStream() { - if (!quirks()->supports_bulk_ingest() || + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || !quirks()->supports_dynamic_parameter_binding()) { GTEST_SKIP(); } @@ -2140,6 +2449,36 @@ void StatementTest::TestSqlQueryStrings() { ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } +void StatementTest::TestSqlQueryCancel() { + if (!quirks()->supports_cancel()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'SaShiSuSeSo'", &error), + IsOkStatus(&error)); + + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_THAT(AdbcStatementCancel(&statement, &error), IsOkStatus(&error)); + while (true) { + int err = reader.MaybeNext(); + if (err != 0) { + ASSERT_THAT(err, ::testing::AnyOf(0, IsErrno(ECANCELED, &reader.stream.value, + /*ArrowError*/ nullptr))); + } + if (!reader.array->release) break; + } + } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestSqlQueryErrors() { // Invalid query ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); @@ -2159,6 +2498,13 @@ void StatementTest::TestTransactions() { ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), IsOkStatus(&error)); + if (quirks()->supports_get_option()) { + auto autocommit = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error); + ASSERT_THAT(autocommit, + ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_ENABLED))); + } + Handle connection2; ASSERT_THAT(AdbcConnectionNew(&connection2.value, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection2.value, &database, &error), @@ -2168,6 +2514,13 @@ void StatementTest::TestTransactions() { ADBC_OPTION_VALUE_DISABLED, &error), IsOkStatus(&error)); + if (quirks()->supports_get_option()) { + auto autocommit = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error); + ASSERT_THAT(autocommit, + ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_DISABLED))); + } + // Uncommitted change ASSERT_NO_FATAL_FAILURE(IngestSampleTable(&connection, &error)); @@ -2243,6 +2596,86 @@ void StatementTest::TestTransactions() { } } +void StatementTest::TestSqlSchemaInts() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("i"), // int32 + ::testing::StrEq("l"), // int64 + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaFloats() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT CAST(1.5 AS FLOAT)", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("f"), // float32 + ::testing::StrEq("g"), // float64 + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaStrings() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'hi'", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("u"), // string + ::testing::StrEq("U"), // large_string + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaErrors() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestConcurrentStatements() { Handle statement1; Handle statement2; @@ -2278,6 +2711,24 @@ void StatementTest::TestConcurrentStatements() { ASSERT_NO_FATAL_FAILURE(reader1.GetSchema()); } +// Test that an ADBC 1.0.0-sized error still works +void StatementTest::TestErrorCompatibility() { + // XXX: sketchy cast + auto* error = static_cast(malloc(ADBC_ERROR_1_0_0_SIZE)); + std::memset(error, 0, ADBC_ERROR_1_0_0_SIZE); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, error), IsOkStatus(error)); + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT * FROM thistabledoesnotexist", error), + IsOkStatus(error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, error), + ::testing::Not(IsOkStatus(error))); + error->release(error); + free(error); +} + void StatementTest::TestResultInvalidation() { // Start reading from a statement, then overwrite it ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); @@ -2296,8 +2747,8 @@ void StatementTest::TestResultInvalidation() { IsOkStatus(&error)); ASSERT_NO_FATAL_FAILURE(reader2.GetSchema()); - // First reader should not fail, but may give no data - ASSERT_NO_FATAL_FAILURE(reader1.Next()); + // First reader may fail, or may succeed but give no data + reader1.MaybeNext(); } #undef NOT_NULL diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index 23dacb7f4b..9bedc6a376 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -31,6 +32,8 @@ namespace adbc_validation { #define ADBCV_STRINGIFY(s) #s #define ADBCV_STRINGIFY_VALUE(s) ADBCV_STRINGIFY(s) +using SqlInfoValue = std::variant; + /// \brief Configuration for driver-specific behavior. class DriverQuirks { public: @@ -85,10 +88,22 @@ class DriverQuirks { return ingest_type; } + /// \brief Whether bulk ingest is supported + virtual bool supports_bulk_ingest(const char* mode) const { return true; } + + /// \brief Whether we can cancel queries. + virtual bool supports_cancel() const { return false; } + /// \brief Whether two statements can be used at the same time on a /// single connection virtual bool supports_concurrent_statements() const { return false; } + /// \brief Whether AdbcStatementExecuteSchema should work + virtual bool supports_execute_schema() const { return false; } + + /// \brief Whether GetOption* should work + virtual bool supports_get_option() const { return true; } + /// \brief Whether AdbcStatementExecutePartitions should work virtual bool supports_partitioned_data() const { return false; } @@ -101,11 +116,19 @@ class DriverQuirks { /// \brief Whether GetSqlInfo is implemented virtual bool supports_get_sql_info() const { return true; } + /// \brief The expected value for a given info code + virtual std::optional supports_get_sql_info(uint32_t info_code) const { + return std::nullopt; + } + /// \brief Whether GetObjects is implemented virtual bool supports_get_objects() const { return true; } - /// \brief Whether bulk ingest is supported - virtual bool supports_bulk_ingest() const { return true; } + /// \brief Whether we can get ADBC_CONNECTION_OPTION_CURRENT_CATALOG + virtual bool supports_metadata_current_catalog() const { return false; } + + /// \brief Whether we can get ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA + virtual bool supports_metadata_current_db_schema() const { return false; } /// \brief Whether dynamic parameter bindings are supported for prepare virtual bool supports_dynamic_parameter_binding() const { return true; } @@ -113,6 +136,9 @@ class DriverQuirks { /// \brief Whether ExecuteQuery sets rows_affected appropriately virtual bool supports_rows_affected() const { return true; } + /// \brief Whether we can get statistics + virtual bool supports_statistics() const { return false; } + /// \brief Default catalog to use for tests virtual std::string catalog() const { return ""; } @@ -157,6 +183,9 @@ class ConnectionTest { void TestAutocommitToggle(); + void TestMetadataCurrentCatalog(); + void TestMetadataCurrentDbSchema(); + void TestMetadataGetInfo(); void TestMetadataGetTableSchema(); void TestMetadataGetTableTypes(); @@ -168,6 +197,9 @@ class ConnectionTest { void TestMetadataGetObjectsColumns(); void TestMetadataGetObjectsConstraints(); void TestMetadataGetObjectsPrimaryKey(); + void TestMetadataGetObjectsCancel(); + + void TestMetadataGetStatisticNames(); protected: struct AdbcError error; @@ -175,28 +207,32 @@ class ConnectionTest { struct AdbcConnection connection; }; -#define ADBCV_TEST_CONNECTION(FIXTURE) \ - static_assert(std::is_base_of::value, \ - ADBCV_STRINGIFY(FIXTURE) " must inherit from ConnectionTest"); \ - TEST_F(FIXTURE, NewInit) { TestNewInit(); } \ - TEST_F(FIXTURE, Release) { TestRelease(); } \ - TEST_F(FIXTURE, Concurrent) { TestConcurrent(); } \ - TEST_F(FIXTURE, AutocommitDefault) { TestAutocommitDefault(); } \ - TEST_F(FIXTURE, AutocommitToggle) { TestAutocommitToggle(); } \ - TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \ - TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \ - TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \ - TEST_F(FIXTURE, MetadataGetObjectsCatalogs) { TestMetadataGetObjectsCatalogs(); } \ - TEST_F(FIXTURE, MetadataGetObjectsDbSchemas) { TestMetadataGetObjectsDbSchemas(); } \ - TEST_F(FIXTURE, MetadataGetObjectsTables) { TestMetadataGetObjectsTables(); } \ - TEST_F(FIXTURE, MetadataGetObjectsTablesTypes) { \ - TestMetadataGetObjectsTablesTypes(); \ - } \ - TEST_F(FIXTURE, MetadataGetObjectsColumns) { TestMetadataGetObjectsColumns(); } \ - TEST_F(FIXTURE, MetadataGetObjectsConstraints) { \ - TestMetadataGetObjectsConstraints(); \ - } \ - TEST_F(FIXTURE, MetadataGetObjectsPrimaryKey) { TestMetadataGetObjectsPrimaryKey(); } +#define ADBCV_TEST_CONNECTION(FIXTURE) \ + static_assert(std::is_base_of::value, \ + ADBCV_STRINGIFY(FIXTURE) " must inherit from ConnectionTest"); \ + TEST_F(FIXTURE, NewInit) { TestNewInit(); } \ + TEST_F(FIXTURE, Release) { TestRelease(); } \ + TEST_F(FIXTURE, Concurrent) { TestConcurrent(); } \ + TEST_F(FIXTURE, AutocommitDefault) { TestAutocommitDefault(); } \ + TEST_F(FIXTURE, AutocommitToggle) { TestAutocommitToggle(); } \ + TEST_F(FIXTURE, MetadataCurrentCatalog) { TestMetadataCurrentCatalog(); } \ + TEST_F(FIXTURE, MetadataCurrentDbSchema) { TestMetadataCurrentDbSchema(); } \ + TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \ + TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \ + TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \ + TEST_F(FIXTURE, MetadataGetObjectsCatalogs) { TestMetadataGetObjectsCatalogs(); } \ + TEST_F(FIXTURE, MetadataGetObjectsDbSchemas) { TestMetadataGetObjectsDbSchemas(); } \ + TEST_F(FIXTURE, MetadataGetObjectsTables) { TestMetadataGetObjectsTables(); } \ + TEST_F(FIXTURE, MetadataGetObjectsTablesTypes) { \ + TestMetadataGetObjectsTablesTypes(); \ + } \ + TEST_F(FIXTURE, MetadataGetObjectsColumns) { TestMetadataGetObjectsColumns(); } \ + TEST_F(FIXTURE, MetadataGetObjectsConstraints) { \ + TestMetadataGetObjectsConstraints(); \ + } \ + TEST_F(FIXTURE, MetadataGetObjectsPrimaryKey) { TestMetadataGetObjectsPrimaryKey(); } \ + TEST_F(FIXTURE, MetadataGetObjectsCancel) { TestMetadataGetObjectsCancel(); } \ + TEST_F(FIXTURE, MetadataGetStatisticNames) { TestMetadataGetStatisticNames(); } class StatementTest { public: @@ -239,6 +275,8 @@ class StatementTest { void TestSqlIngestTableEscaping(); void TestSqlIngestAppend(); + void TestSqlIngestReplace(); + void TestSqlIngestCreateAppend(); void TestSqlIngestErrors(); void TestSqlIngestMultipleConnections(); void TestSqlIngestSample(); @@ -258,11 +296,19 @@ class StatementTest { void TestSqlQueryFloats(); void TestSqlQueryStrings(); + void TestSqlQueryCancel(); void TestSqlQueryErrors(); + void TestSqlSchemaInts(); + void TestSqlSchemaFloats(); + void TestSqlSchemaStrings(); + + void TestSqlSchemaErrors(); + void TestTransactions(); void TestConcurrentStatements(); + void TestErrorCompatibility(); void TestResultInvalidation(); protected: @@ -308,6 +354,8 @@ class StatementTest { TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \ TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \ TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \ + TEST_F(FIXTURE, SqlIngestReplace) { TestSqlIngestReplace(); } \ + TEST_F(FIXTURE, SqlIngestCreateAppend) { TestSqlIngestCreateAppend(); } \ TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); } \ TEST_F(FIXTURE, SqlIngestMultipleConnections) { TestSqlIngestMultipleConnections(); } \ TEST_F(FIXTURE, SqlIngestSample) { TestSqlIngestSample(); } \ @@ -325,9 +373,15 @@ class StatementTest { TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \ TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \ TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); } \ + TEST_F(FIXTURE, SqlQueryCancel) { TestSqlQueryCancel(); } \ TEST_F(FIXTURE, SqlQueryErrors) { TestSqlQueryErrors(); } \ + TEST_F(FIXTURE, SqlSchemaInts) { TestSqlSchemaInts(); } \ + TEST_F(FIXTURE, SqlSchemaFloats) { TestSqlSchemaFloats(); } \ + TEST_F(FIXTURE, SqlSchemaStrings) { TestSqlSchemaStrings(); } \ + TEST_F(FIXTURE, SqlSchemaErrors) { TestSqlSchemaErrors(); } \ TEST_F(FIXTURE, Transactions) { TestTransactions(); } \ TEST_F(FIXTURE, ConcurrentStatements) { TestConcurrentStatements(); } \ + TEST_F(FIXTURE, ErrorCompatibility) { TestErrorCompatibility(); } \ TEST_F(FIXTURE, ResultInvalidation) { TestResultInvalidation(); } } // namespace adbc_validation diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h index 7c60e3cd6a..5c89fa25f1 100644 --- a/c/validation/adbc_validation_util.h +++ b/c/validation/adbc_validation_util.h @@ -31,6 +31,7 @@ #include #include #include + #include "common/utils.h" namespace adbc_validation { diff --git a/ci/docker/python-wheel-manylinux.dockerfile b/ci/docker/python-wheel-manylinux.dockerfile index 6afe83587d..aeac91e0a2 100644 --- a/ci/docker/python-wheel-manylinux.dockerfile +++ b/ci/docker/python-wheel-manylinux.dockerfile @@ -27,6 +27,6 @@ ARG ARCH RUN yum install -y docker # arm64v8 -> arm64 -RUN wget --no-verbose https://go.dev/dl/go1.19.5.linux-${ARCH/v8/}.tar.gz -RUN tar -C /usr/local -xzf go1.19.5.linux-${ARCH/v8/}.tar.gz +RUN wget --no-verbose https://go.dev/dl/go1.18.10.linux-${ARCH/v8/}.tar.gz +RUN tar -C /usr/local -xzf go1.18.10.linux-${ARCH/v8/}.tar.gz ENV PATH="${PATH}:/usr/local/go/bin" diff --git a/ci/scripts/python_wheel_unix_build.sh b/ci/scripts/python_wheel_unix_build.sh index 3eea17f5c5..fa614bfded 100755 --- a/ci/scripts/python_wheel_unix_build.sh +++ b/ci/scripts/python_wheel_unix_build.sh @@ -35,7 +35,7 @@ function check_visibility { # Filter out Arrow symbols and see if anything remains. # '_init' and '_fini' symbols may or not be present, we don't care. # (note we must ignore the grep exit status when no match is found) - grep ' T ' nm_arrow.log | grep -v -E '(Adbc|\b_init\b|\b_fini\b)' | cat - > visible_symbols.log + grep ' T ' nm_arrow.log | grep -v -E '(Adbc|DriverInit|\b_init\b|\b_fini\b)' | cat - > visible_symbols.log if [[ -f visible_symbols.log && `cat visible_symbols.log | wc -l` -eq 0 ]]; then return 0 diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index bb5dae2d7d..9e12e5423d 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -23,7 +23,7 @@ # - Maven >= 3.3.9 # - JDK >=7 # - gcc >= 4.8 -# - Go >= 1.18 +# - Go >= 1.20 # - Docker # # To reuse build artifacts between runs set ARROW_TMPDIR environment variable to @@ -254,7 +254,7 @@ install_go() { return 0 fi - local version=1.18.8 + local version=1.18.10 show_info "Installing go version ${version}..." local arch="$(uname -m)" @@ -411,7 +411,7 @@ test_cpp() { maybe_setup_conda \ --file ci/conda_env_cpp.txt \ compilers \ - go=1.18 || exit 1 + go=1.20 || exit 1 if [ "${USE_CONDA}" -gt 0 ]; then export CMAKE_PREFIX_PATH="${CONDA_BACKUP_CMAKE_PREFIX_PATH}:${CMAKE_PREFIX_PATH}" @@ -561,7 +561,7 @@ test_go() { # apache/arrow-adbc#517: `go build` calls git. Don't assume system # has git; even if it's there, go_build.sh sets DYLD_LIBRARY_PATH # which can interfere with system git. - maybe_setup_conda compilers git go=1.18 || exit 1 + maybe_setup_conda compilers git go=1.20 || exit 1 if [ "${USE_CONDA}" -gt 0 ]; then # The CMake setup forces RPATH to be the Conda prefix diff --git a/docker-compose.yml b/docker-compose.yml index eea987a106..dd0ef2f53f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -30,19 +30,6 @@ services: command: | /bin/bash -c 'git config --global --add safe.directory /adbc && source /opt/conda/etc/profile.d/conda.sh && mamba create -y -n adbc -c conda-forge go --file /adbc/ci/conda_env_cpp.txt --file /adbc/ci/conda_env_docs.txt --file /adbc/ci/conda_env_python.txt && conda activate adbc && env ADBC_USE_ASAN=0 ADBC_USE_UBSAN=0 /adbc/ci/scripts/cpp_build.sh /adbc /adbc/build && env CGO_ENABLED=1 /adbc/ci/scripts/go_build.sh /adbc /adbc/build && /adbc/ci/scripts/python_build.sh /adbc /adbc/build && /adbc/ci/scripts/r_build.sh /adbc && /adbc/ci/scripts/docs_build.sh /adbc' - golang-sqlite-flightsql: - image: ${REPO}:golang-${GO}-sqlite-flightsql - build: - context: . - cache_from: - - ${REPO}:golang-${GO}-sqlite-flightsql - dockerfile: ci/docker/golang-flightsql-sqlite.dockerfile - args: - GO: ${GO} - ARROW_MAJOR_VERSION: ${ARROW_MAJOR_VERSION} - ports: - - 8080:8080 - ############################ Java JARs ###################################### java-dist: @@ -162,3 +149,16 @@ services: entrypoint: "/init/bootstrap.sh" volumes: - "./ci/scripts/integration/dremio:/init" + + golang-sqlite-flightsql: + image: ${REPO}:golang-${GO}-sqlite-flightsql + build: + context: . + cache_from: + - ${REPO}:golang-${GO}-sqlite-flightsql + dockerfile: ci/docker/golang-flightsql-sqlite.dockerfile + args: + GO: ${GO} + ARROW_MAJOR_VERSION: ${ARROW_MAJOR_VERSION} + ports: + - 8080:8080 diff --git a/docs/source/format/specification.rst b/docs/source/format/specification.rst index e7a44a4198..89fc3597a0 100644 --- a/docs/source/format/specification.rst +++ b/docs/source/format/specification.rst @@ -57,6 +57,26 @@ implementations will support this. - Go: ``OptionKeyAutoCommit`` - Java: ``org.apache.arrow.adbc.core.AdbcConnection#setAutoCommit(boolean)`` +Metadata +-------- + +ADBC exposes a variety of metadata about the database, such as what catalogs, +schemas, and tables exist, the Arrow schema of tables, and so on. + +.. _specification-statistics: + +Statistics +---------- + +.. note:: Since API revision 1.1.0 + +ADBC exposes table/column statistics, such as the (unique) row count, min/max +values, and so on. The goal here is to make ADBC work better in federation +scenarios, where one query engine wants to read Arrow data from another +database. Having statistics available lets the "outer" query planner make +better choices about things like join order, or even decide to skip reading +data entirely. + Statements ========== @@ -84,6 +104,16 @@ frees the user from knowing the right SQL syntax for their database. - Go: ``OptionKeyIngestTargetTable`` - Java: ``org.apache.arrow.adbc.core.AdbcConnection#bulkIngest(String, org.apache.arrow.adbc.core.BulkIngestMode)`` +.. _specification-cancellation: + +Cancellation +------------ + +.. note:: Since API revision 1.1.0 + +Queries (and operations that implicitly represent queries, like fetching +:ref:`specification-statistics`) can be cancelled. + Partitioned Result Sets ----------------------- @@ -97,6 +127,16 @@ machines. - Go: ``Statement.ExecutePartitions`` - Java: ``org.apache.arrow.adbc.core.AdbcStatement#executePartitioned()`` +.. _specification-incremental-execution: + +In principle, a vendor could return the results of partitioned execution as +they are available, instead of all at once. Incremental execution allows +drivers to expose this. When enabled, each call to ``ExecutePartitions`` will +return available endpoints to read instead of blocking to retrieve all +endpoints. + +.. note:: Since API revision 1.1.0 + Lifecycle & Usage ----------------- @@ -135,3 +175,74 @@ Partitioned Execution .. mermaid:: AdbcStatementPartitioned.mmd :caption: This is similar to fetching data in Arrow Flight RPC (by design). See :doc:`"Downloading Data" `. + +Error Handling +============== + +The error handling strategy varies by language. + +In C, most methods take a :cpp:class:`AdbcError`. In Go, most methods return +an error that can be cast to an ``AdbcError``. In Java, most methods raise an +``AdbcException``. + +In all cases, an error contains: + +- A status code, +- An error message, +- An optional vendor code (a vendor-specific status code), +- An optional 5-character "SQLSTATE" code (a SQL-like vendor-specific code). + +.. _specification-rich-error-metadata: + +Rich Error Metadata +------------------- + +.. note:: Since API revision 1.1.0 + +Drivers can expose additional rich error metadata. This can be used to return +structured error information. For example, a driver could use something like +the `Googleapis ErrorDetails`_. + +In C, Go and Java, :cpp:class:`AdbcError`, ``AdbcError``, and +``AdbcException`` respectively expose a list of additional metadata. For C, +see the documentation of :cpp:class:`AdbcError` to learn how the struct was +expanded while preserving ABI. + +.. _Googleapis ErrorDetails: https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto + +Changelog +========= + +Version 1.1.0 +------------- + +The info key ADBC_INFO_DRIVER_ADBC_VERSION can be used to retrieve the +driver's supported ADBC version. + +The canonical options "uri", "username", and "password" were added to make +configuration consistent between drivers. + +:ref:`specification-cancellation` and the ability to both get and set options +of different types were added. (Previously, you could set string options but +could not get option values or get/set values of other types.) This can be +used to get and set the current active catalog and/or schema through a pair of +new canonical options. + +:ref:`specification-bulk-ingestion` supports two additional modes: + +- "adbc.ingest.mode.replace" will drop existing data, then behave like + "create". +- "adbc.ingest.mode.create_append" will behave like "create", except if the + table already exists, it will not error. + +:ref:`specification-rich-error-metadata` has been added, allowing clients to +get additional error metadata. + +The ability to retrive table/column :ref:`statistics +` was added. The goal here is to make ADBC work +better in federation scenarios, where one query engine wants to read Arrow +data from another database. + +:ref:`Incremental execution ` allows +streaming partitions of a result set as they are available instead of blocking +and waiting for query execution to finish before reading results. diff --git a/docs/source/format/versioning.rst b/docs/source/format/versioning.rst index 3205b792e1..b255aeebbe 100644 --- a/docs/source/format/versioning.rst +++ b/docs/source/format/versioning.rst @@ -29,14 +29,19 @@ choices were made: Of course, we can never add/remove/change struct members, and we can never change the signatures of existing functions. -The main point of concern is compatibility of :cpp:class:`AdbcDriver`. +In ADBC 1.1.0, it was decided this would only apply to the "public" +API, and not the driver-internal API (:cpp:class:`AdbcDriver`). New +members were added to this struct in the 1.1.0 revision. +Compatibility is handled as follows: The driver entrypoint, :cpp:type:`AdbcDriverInitFunc`, is given a -version and a pointer to a table of function pointers to initialize. -The type of the table will depend on the version; when a new version -of ADBC is accepted, then a new table of function pointers will be -added. That way, the driver knows the type of the table. If/when we -add a new ADBC version, the following scenarios are possible: +version and a pointer to a table of function pointers to initialize +(the :cpp:class:`AdbcDriver`). The size of the table will depend on +the version; when a new version of ADBC is accepted, then a new table +of function pointers may be expanded. For each version, the driver +knows the expected size of the table, and must not read/write fields +beyond that size. If/when we add a new ADBC version, the following +scenarios are possible: - An updated client application uses an old driver library. The client will pass a `version` field greater than what the driver @@ -46,7 +51,8 @@ add a new ADBC version, the following scenarios are possible: - An old client application uses an updated driver library. The client will pass a ``version`` lower than what the driver recognizes, so the driver can either error, or if it can still - implement the old API contract, initialize the older table. + implement the old API contract, initialize the subset of the table + corresponding to the older version. This approach does not let us change the signatures of existing functions, but we can add new functions and remove existing ones. @@ -64,7 +70,7 @@ backwards-incompatible versions such as 2.0.0, but which still implement the API standard version 1.0.0. Similarly, this documentation describes the ADBC API standard version -1.0.0. If/when a compatible revision is made (e.g. new standard -options are defined), the next version would be 1.1.0. If -incompatible changes are made (e.g. new API functions), the next -version would be 2.0.0. +1.1.0. If/when a compatible revision is made (e.g. new standard +options or API functions are defined), the next version would be +1.2.0. If incompatible changes are made (e.g. changing the signature +or semantics of a function), the next version would be 2.0.0. diff --git a/docs/source/python/api/adbc_driver_manager.rst b/docs/source/python/api/adbc_driver_manager.rst index c0d22b62ec..7023af6ace 100644 --- a/docs/source/python/api/adbc_driver_manager.rst +++ b/docs/source/python/api/adbc_driver_manager.rst @@ -31,9 +31,11 @@ Constants & Enums .. autoclass:: adbc_driver_manager.AdbcStatusCode :members: + :undoc-members: .. autoclass:: adbc_driver_manager.GetObjectsDepth :members: + :undoc-members: .. autoclass:: adbc_driver_manager.ConnectionOptions :members: diff --git a/docs/source/python/recipe/postgresql_create_append_table.py b/docs/source/python/recipe/postgresql_create_append_table.py index 9b0c66f989..a2f6258ce5 100644 --- a/docs/source/python/recipe/postgresql_create_append_table.py +++ b/docs/source/python/recipe/postgresql_create_append_table.py @@ -63,7 +63,7 @@ with conn.cursor() as cur: try: cur.adbc_ingest("example", data, mode="create") - except conn.OperationalError: + except conn.ProgrammingError: pass else: raise RuntimeError("Should have failed!") diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index 92df909b98..ad6194f240 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -42,11 +42,79 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" ) //go:generate go run golang.org/x/tools/cmd/stringer -type Status -linecomment //go:generate go run golang.org/x/tools/cmd/stringer -type InfoCode -linecomment +// ErrorDetail is additional driver-specific error metadata. +// +// This allows drivers to return custom, structured error information (for +// example, JSON or Protocol Buffers) that can be optionally parsed by +// clients, beyond the standard Error fields, without having to encode it in +// the error message. +type ErrorDetail interface { + // Get an identifier for the detail (e.g. if the metadata comes from an HTTP + // header, the key could be the header name). + // + // This allows clients and drivers to cooperate and provide some idea of what + // to expect in the detail. + Key() string + // Serialize the detail value to a byte array for interoperability with C/C++. + Serialize() ([]byte, error) +} + +// ProtobufErrorDetail is an ErrorDetail backed by a Protobuf message. +type ProtobufErrorDetail struct { + Name string + Message proto.Message +} + +func (d *ProtobufErrorDetail) Key() string { + return d.Name +} + +// Serialize serializes the Protobuf message (wrapped in Any). +func (d *ProtobufErrorDetail) Serialize() ([]byte, error) { + any, err := anypb.New(d.Message) + if err != nil { + return nil, err + } + return proto.Marshal(any) +} + +// ProtobufErrorDetail is an ErrorDetail backed by a human-readable string. +type TextErrorDetail struct { + Name string + Detail string +} + +func (d *TextErrorDetail) Key() string { + return d.Name +} + +// Serialize serializes the Protobuf message (wrapped in Any). +func (d *TextErrorDetail) Serialize() ([]byte, error) { + return []byte(d.Detail), nil +} + +// ProtobufErrorDetail is an ErrorDetail backed by a binary payload. +type BinaryErrorDetail struct { + Name string + Detail []byte +} + +func (d *BinaryErrorDetail) Key() string { + return d.Name +} + +// Serialize serializes the Binary message (wrapped in Any). +func (d *BinaryErrorDetail) Serialize() ([]byte, error) { + return d.Detail, nil +} + // Error is the detailed error for an operation type Error struct { // Msg is a string representing a human readable error message @@ -58,6 +126,8 @@ type Error struct { // SqlState is a SQLSTATE error code, if provided, as defined // by the SQL:2003 standard. If not set, it will be "\0\0\0\0\0" SqlState [5]byte + // Details is an array of additional driver-specific error details. + Details []ErrorDetail } func (e Error) Error() string { @@ -142,20 +212,36 @@ const ( StatusUnauthorized // Unauthorized ) +const ( + AdbcVersion1_0_0 int64 = 1_000_000 + AdbcVersion1_1_0 int64 = 1_001_000 +) + // Canonical option values const ( - OptionValueEnabled = "true" - OptionValueDisabled = "false" - OptionKeyAutoCommit = "adbc.connection.autocommit" - OptionKeyIngestTargetTable = "adbc.ingest.target_table" - OptionKeyIngestMode = "adbc.ingest.mode" - OptionKeyIsolationLevel = "adbc.connection.transaction.isolation_level" - OptionKeyReadOnly = "adbc.connection.readonly" - OptionValueIngestModeCreate = "adbc.ingest.mode.create" - OptionValueIngestModeAppend = "adbc.ingest.mode.append" - OptionKeyURI = "uri" - OptionKeyUsername = "username" - OptionKeyPassword = "password" + OptionValueEnabled = "true" + OptionValueDisabled = "false" + OptionKeyAutoCommit = "adbc.connection.autocommit" + // The current catalog. + OptionKeyCurrentCatalog = "adbc.connection.catalog" + // The current schema. + OptionKeyCurrentDbSchema = "adbc.connection.db_schema" + // Make ExecutePartitions nonblocking. + OptionKeyIncremental = "adbc.statement.exec.incremental" + // Get the progress + OptionKeyProgress = "adbc.statement.exec.progress" + OptionKeyMaxProgress = "adbc.statement.exec.max_progress" + OptionKeyIngestTargetTable = "adbc.ingest.target_table" + OptionKeyIngestMode = "adbc.ingest.mode" + OptionKeyIsolationLevel = "adbc.connection.transaction.isolation_level" + OptionKeyReadOnly = "adbc.connection.readonly" + OptionValueIngestModeCreate = "adbc.ingest.mode.create" + OptionValueIngestModeAppend = "adbc.ingest.mode.append" + OptionValueIngestModeReplace = "adbc.ingest.mode.replace" + OptionValueIngestModeCreateAppend = "adbc.ingest.mode.create_append" + OptionKeyURI = "uri" + OptionKeyUsername = "username" + OptionKeyPassword = "password" ) type OptionIsolationLevel string @@ -170,6 +256,51 @@ const ( LevelLinearizable OptionIsolationLevel = "adbc.connection.transaction.isolation.linearizable" ) +// Standard statistic names and keys. +const ( + // The dictionary-encoded name of the average byte width statistic. + StatisticAverageByteWidthKey = 0 + // The average byte width statistic. The average size in bytes of a row in + // the column. Value type is float64. + // + // For example, this is roughly the average length of a string for a string + // column. + StatisticAverageByteWidthName = "adbc.statistic.byte_width" + // The dictionary-encoded name of the distinct value count statistic. + StatisticDistinctCountKey = 1 + // The distinct value count (NDV) statistic. The number of distinct values in + // the column. Value type is int64 (when not approximate) or float64 (when + // approximate). + StatisticDistinctCountName = "adbc.statistic.distinct_count" + // The dictionary-encoded name of the max byte width statistic. + StatisticMaxByteWidthKey = 2 + // The max byte width statistic. The maximum size in bytes of a row in the + // column. Value type is int64 (when not approximate) or float64 (when + // approximate). + // + // For example, this is the maximum length of a string for a string column. + StatisticMaxByteWidthName = "adbc.statistic.byte_width" + // The dictionary-encoded name of the max value statistic. + StatisticMaxValueKey = 3 + // The max value statistic. Value type is column-dependent. + StatisticMaxValueName = "adbc.statistic.byte_width" + // The dictionary-encoded name of the min value statistic. + StatisticMinValueKey = 4 + // The min value statistic. Value type is column-dependent. + StatisticMinValueName = "adbc.statistic.byte_width" + // The dictionary-encoded name of the null count statistic. + StatisticNullCountKey = 5 + // The null count statistic. The number of values that are null in the + // column. Value type is int64 (when not approximate) or float64 (when + // approximate). + StatisticNullCountName = "adbc.statistic.null_count" + // The dictionary-encoded name of the row count statistic. + StatisticRowCountKey = 6 + // The row count statistic. The number of rows in the column or table. Value + // type is int64 (when not approximate) or float64 (when approximate). + StatisticRowCountName = "adbc.statistic.row_count" +) + // Driver is the entry point for the interface. It is similar to // database/sql.Driver taking a map of keys and values as options // to initialize a Connection to the database. Any common connection @@ -212,6 +343,8 @@ const ( InfoDriverVersion InfoCode = 101 // DriverVersion // The driver Arrow library version (type: utf8) InfoDriverArrowVersion InfoCode = 102 // DriverArrowVersion + // The driver ADBC API version (type: int64) + InfoDriverADBCVersion InfoCode = 103 // DriverADBCVersion ) type ObjectDepth int @@ -275,6 +408,10 @@ type Connection interface { // codes are defined as constants. Codes [0, 10_000) are reserved // for ADBC usage. Drivers/vendors will ignore requests for unrecognized // codes (the row will be omitted from the result). + // + // Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" + // information, which is the same metadata provided by the same info + // code range in the Arrow Flight SQL GetSqlInfo RPC. GetInfo(ctx context.Context, infoCodes []InfoCode) (array.RecordReader, error) // GetObjects gets a hierarchical view of all catalogs, database schemas, @@ -470,6 +607,9 @@ type Statement interface { // of rows affected if known, otherwise it will be -1. // // This invalidates any prior result sets on this statement. + // + // Since ADBC 1.1.0: releasing the returned RecordReader without + // consuming it fully is equivalent to calling AdbcStatementCancel. ExecuteQuery(context.Context) (array.RecordReader, int64, error) // ExecuteUpdate executes a statement that does not generate a result @@ -534,5 +674,106 @@ type Statement interface { // // If the driver does not support partitioned results, this will return // an error with a StatusNotImplemented code. + // + // When OptionKeyIncremental is set, this should be called + // repeatedly until receiving an empty Partitions. ExecutePartitions(context.Context) (*arrow.Schema, Partitions, int64, error) } + +// ConnectionGetStatistics is a Connection that supports getting +// statistics on data in the database. +// +// Since ADBC API revision 1.1.0. +type ConnectionGetStatistics interface { + // GetStatistics gets statistics about the data distribution of table(s). + // + // The result is an Arrow dataset with the following schema: + // + // Field Name | Field Type + // -------------------------|---------------------------------- + // catalog_name | utf8 + // catalog_db_schemas | list not null + // + // DB_SCHEMA_SCHEMA is a Struct with fields: + // + // Field Name | Field Type + // -------------------------|---------------------------------- + // db_schema_name | utf8 + // db_schema_statistics | list not null + // + // STATISTICS_SCHEMA is a Struct with fields: + // + // Field Name | Field Type | Comments + // -------------------------|----------------------------------| -------- + // table_name | utf8 not null | + // column_name | utf8 | (1) + // statistic_key | int16 not null | (2) + // statistic_value | VALUE_SCHEMA not null | + // statistic_is_approximate | bool not null | (3) + // + // 1. If null, then the statistic applies to the entire table. + // 2. A dictionary-encoded statistic name (although we do not use the Arrow + // dictionary type). Values in [0, 1024) are reserved for ADBC. Other + // values are for implementation-specific statistics. For the definitions + // of predefined statistic types, see the Statistic constants. To get + // driver-specific statistic names, use AdbcConnectionGetStatisticNames. + // 3. If true, then the value is approximate or best-effort. + // + // VALUE_SCHEMA is a dense union with members: + // + // Field Name | Field Type + // -------------------------|---------------------------------- + // int64 | int64 + // uint64 | uint64 + // float64 | float64 + // binary | binary + // + // For the parameters: If nil is passed, then that parameter will not + // be filtered by at all. If an empty string, then only objects without + // that property (ie: catalog or db schema) will be returned. + // + // All non-empty, non-nil strings should be a search pattern (as described + // earlier). + // + // approximate indicates whether to request exact values of statistics, or + // best-effort/cached values. Requesting exact values may be expensive or + // unsupported. + GetStatistics(ctx context.Context, catalog, dbSchema, tableName *string, approximate bool) (array.RecordReader, error) + + // GetStatisticNames gets a list of custom statistic names defined by this driver. + // + // The result is an Arrow dataset with the following schema: + // + // Field Name | Field Type + // ---------------|---------------- + // statistic_name | utf8 not null + // statistic_key | int16 not null + // + GetStatisticNames(ctx context.Context) (array.RecordReader, error) +} + +// StatementExecuteSchema is a Statement that also supports ExecuteSchema. +// +// Since ADBC API revision 1.1.0. +type StatementExecuteSchema interface { + // ExecuteSchema gets the schema of the result set of a query without executing it. + ExecuteSchema(context.Context) (*arrow.Schema, error) +} + +// GetSetOptions is a PostInitOptions that also supports getting and setting option values of different types. +// +// GetOption functions should return an error with StatusNotFound for unsupported options. +// SetOption functions should return an error with StatusNotImplemented for unsupported options. +// +// Since ADBC API revision 1.1.0. +type GetSetOptions interface { + PostInitOptions + + SetOptionBytes(key string, value []byte) error + SetOptionInt(key string, value int64) error + SetOptionDouble(key string, value float64) error + GetOption(key string) (string, error) + GetOptionBytes(key string) ([]byte, error) + GetOptionInt(key string) (int64, error) + GetOptionDouble(key string) (float64, error) +} diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go index e00310cfc3..70ba50182d 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc.go +++ b/go/adbc/driver/flightsql/flightsql_adbc.go @@ -36,7 +36,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "fmt" "io" "math" @@ -119,20 +118,13 @@ func init() { adbc.InfoDriverName, adbc.InfoDriverVersion, adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, adbc.InfoVendorName, adbc.InfoVendorVersion, adbc.InfoVendorArrowVersion, } } -func getTimeoutOptionValue(v string) (time.Duration, error) { - timeout, err := strconv.ParseFloat(v, 64) - if math.IsNaN(timeout) || math.IsInf(timeout, 0) || timeout < 0 { - return 0, errors.New("timeout must be positive and finite") - } - return time.Duration(timeout * float64(time.Second)), err -} - type Driver struct { Alloc memory.Allocator } @@ -164,6 +156,8 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) { db.dialOpts.block = false db.dialOpts.maxMsgSize = 16 * 1024 * 1024 + db.options = make(map[string]string) + return db, db.SetOptions(opts) } @@ -192,6 +186,7 @@ type database struct { timeout timeoutOption dialOpts dbDialOpts enableCookies bool + options map[string]string alloc memory.Allocator } @@ -199,6 +194,10 @@ type database struct { func (d *database) SetOptions(cnOptions map[string]string) error { var tlsConfig tls.Config + for k, v := range cnOptions { + d.options[k] = v + } + mtlsCert := cnOptions[OptionMTLSCertChain] mtlsKey := cnOptions[OptionMTLSPrivateKey] switch { @@ -287,33 +286,24 @@ func (d *database) SetOptions(cnOptions map[string]string) error { var err error if tv, ok := cnOptions[OptionTimeoutFetch]; ok { - if d.timeout.fetchTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutFetch) } if tv, ok := cnOptions[OptionTimeoutQuery]; ok { - if d.timeout.queryTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutQuery, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutQuery, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutQuery) } if tv, ok := cnOptions[OptionTimeoutUpdate]; ok { - if d.timeout.updateTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutUpdate, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutUpdate, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutUpdate) } if val, ok := cnOptions[OptionWithBlock]; ok { @@ -369,7 +359,7 @@ func (d *database) SetOptions(cnOptions map[string]string) error { continue } return adbc.Error{ - Msg: fmt.Sprintf("Unknown database option '%s'", key), + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), Code: adbc.StatusInvalidArgument, } } @@ -377,6 +367,114 @@ func (d *database) SetOptions(cnOptions map[string]string) error { return nil } +func (d *database) GetOption(key string) (string, error) { + switch key { + case OptionTimeoutFetch: + return d.timeout.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return d.timeout.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return d.timeout.updateTimeout.String(), nil + } + if val, ok := d.options[key]; ok { + return val, nil + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := d.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return d.timeout.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return d.timeout.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return d.timeout.updateTimeout.Seconds(), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) SetOption(key, value string) error { + // We can't change most options post-init + switch key { + case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: + return d.timeout.setTimeoutString(key, value) + } + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + d.hdrs.Set(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), value) + } + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionInt(key string, value int64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeout(key, float64(value)) + } + + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeout(key, value) + } + + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + type timeoutOption struct { grpc.EmptyCallOption @@ -388,6 +486,45 @@ type timeoutOption struct { updateTimeout time.Duration } +func (t *timeoutOption) setTimeout(key string, value float64) error { + if math.IsNaN(value) || math.IsInf(value, 0) || value < 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %f: timeouts must be non-negative and finite", + key, value), + Code: adbc.StatusInvalidArgument, + } + } + + timeout := time.Duration(value * float64(time.Second)) + + switch key { + case OptionTimeoutFetch: + t.fetchTimeout = timeout + case OptionTimeoutQuery: + t.queryTimeout = timeout + case OptionTimeoutUpdate: + t.updateTimeout = timeout + default: + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key), + Code: adbc.StatusNotImplemented, + } + } + return nil +} + +func (t *timeoutOption) setTimeoutString(key string, value string) error { + timeout, err := strconv.ParseFloat(value, 64) + if err != nil { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %s: %s", + key, value, err.Error()), + Code: adbc.StatusInvalidArgument, + } + } + return t.setTimeout(key, timeout) +} + func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) { for _, opt := range callOptions { if to, ok := opt.(timeoutOption); ok { @@ -590,12 +727,10 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C cl.Alloc = d.alloc if d.user != "" || d.pass != "" { - ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass) + var header, trailer metadata.MD + ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout) if err != nil { - return nil, adbc.Error{ - Msg: err.Error(), - Code: adbc.StatusUnauthenticated, - } + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken") } if md, ok := metadata.FromOutgoingContext(ctx); ok { @@ -729,52 +864,115 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd return nil, err } -func (c *cnxn) SetOption(key, value string) error { +func (c *cnxn) GetOption(key string) (string, error) { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) - if value == "" { - c.hdrs.Delete(name) - } else { - c.hdrs.Append(name, value) + headers := c.hdrs.Get(name) + if len(headers) > 0 { + return headers[0], nil + } + return "", adbc.Error{ + Msg: "[Flight SQL] unknown header", + Code: adbc.StatusNotFound, } - return nil } switch key { case OptionTimeoutFetch: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } + return c.timeouts.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return c.timeouts.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return c.timeouts.updateTimeout.String(), nil + case adbc.OptionKeyAutoCommit: + if c.txn != nil { + // No autocommit + return adbc.OptionValueDisabled, nil + } else { + // Autocommit + return adbc.OptionValueEnabled, nil + } + case adbc.OptionKeyCurrentCatalog: + return "", adbc.Error{ + Msg: "[Flight SQL] current catalog not supported", + Code: adbc.StatusNotFound, + } + + case adbc.OptionKeyCurrentDbSchema: + return "", adbc.Error{ + Msg: "[Flight SQL] current schema not supported", + Code: adbc.StatusNotFound, } - c.timeouts.fetchTimeout = timeout + } + + return "", adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough case OptionTimeoutQuery: - timeout, err := getTimeoutOptionValue(value) + fallthrough + case OptionTimeoutUpdate: + val, err := c.GetOptionDouble(key) if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } + return 0, err } - c.timeouts.queryTimeout = timeout + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return c.timeouts.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return c.timeouts.queryTimeout.Seconds(), nil case OptionTimeoutUpdate: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } + return c.timeouts.updateTimeout.Seconds(), nil + } + + return 0.0, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) SetOption(key, value string) error { + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) + if value == "" { + c.hdrs.Delete(name) + } else { + c.hdrs.Append(name, value) } - c.timeouts.updateTimeout = timeout + return nil + } + + switch key { + case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: + return c.timeouts.setTimeoutString(key, value) case adbc.OptionKeyAutoCommit: autocommit := true switch value { case adbc.OptionValueEnabled: + // Do nothing case adbc.OptionValueDisabled: autocommit = false default: @@ -823,8 +1021,41 @@ func (c *cnxn) SetOption(key, value string) error { Code: adbc.StatusNotImplemented, } } +} - return nil +func (c *cnxn) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionInt(key string, value int64) error { + switch key { + case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: + return c.timeouts.setTimeout(key, float64(value)) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return c.timeouts.setTimeout(key, value) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } } // GetInfo returns metadata about the database/driver. @@ -853,6 +1084,7 @@ func (c *cnxn) SetOption(key, value string) error { // codes (the row will be omitted from the result). func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { const strValTypeID arrow.UnionTypeCode = 0 + const intValTypeID arrow.UnionTypeCode = 2 if len(infoCodes) == 0 { infoCodes = infoSupportedCodes @@ -864,7 +1096,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr := bldr.Field(0).(*array.Uint32Builder) infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder) + strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) translated := make([]flightsql.SqlInfo, 0, len(infoCodes)) for _, code := range infoCodes { @@ -886,16 +1119,22 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) strInfoBldr.Append(infoDriverArrowVersion) + case adbc.InfoDriverADBCVersion: + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(intValTypeID) + intInfoBldr.Append(adbc.AdbcVersion1_1_0) } } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - info, err := c.cl.GetSqlInfo(ctx, translated, c.timeouts) + var header, trailer metadata.MD + info, err := c.cl.GetSqlInfo(ctx, translated, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err == nil { for i, endpoint := range info.Endpoint { - rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, c.timeouts) + var header, trailer metadata.MD + rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) } for rdr.Next() { @@ -911,6 +1150,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(adbc.InfoVendorVersion)) case flightsql.SqlInfoFlightSqlServerArrowVersion: infoNameBldr.Append(uint32(adbc.InfoVendorArrowVersion)) + default: + continue } infoValueBldr.Append(info.TypeCode(i)) @@ -921,8 +1162,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re } } - if rdr.Err() != nil { - return nil, adbcFromFlightStatus(rdr.Err(), "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + if err := checkContext(rdr.Err(), ctx); err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) } } } else if grpcstatus.Code(err) != grpccodes.Unimplemented { @@ -1029,15 +1270,18 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * } defer g.Release() + var header, trailer metadata.MD // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. - info, err := c.cl.GetCatalogs(ctx) + info, err := c.cl.GetCatalogs(ctx, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } - rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info) + header = metadata.MD{} + trailer = metadata.MD{} + rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } defer rdr.Release() @@ -1057,17 +1301,16 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * g.AppendCatalog("") } - if err = rdr.Err(); err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") + if err := checkContext(rdr.Err(), ctx); err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } - return g.Finish() } // Helper function to read and validate a metadata stream -func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo) (array.RecordReader, error) { +func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) { // use a default queueSize for the reader - rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5) + rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5, opts...) if err != nil { return nil, adbcFromFlightStatus(err, "DoGet") } @@ -1088,15 +1331,18 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, return } result = make(map[string][]string) + var header, trailer metadata.MD // Pre-populate the map of which schemas are in which catalogs - info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema}) + info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetDBSchemas)") } - rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info) + header = metadata.MD{} + trailer = metadata.MD{} + rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetDBSchemas)") } defer rdr.Release() @@ -1115,9 +1361,8 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, } } - if rdr.Err() != nil { - result = nil - err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetDBSchemas)") + if err := checkContext(rdr.Err(), ctx); err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } return } @@ -1130,21 +1375,24 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat // Pre-populate the map of which schemas are in which catalogs includeSchema := depth == adbc.ObjectDepthAll || depth == adbc.ObjectDepthColumns + var header, trailer metadata.MD info, err := c.cl.GetTables(ctx, &flightsql.GetTablesOpts{ DbSchemaFilterPattern: dbSchema, TableNameFilterPattern: tableName, TableTypes: tableType, IncludeSchema: includeSchema, - }) + }, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetTables)") } expectedSchema := schema_ref.Tables if includeSchema { expectedSchema = schema_ref.TablesWithIncludedSchema } - rdr, err := c.readInfo(ctx, expectedSchema, info) + header = metadata.MD{} + trailer = metadata.MD{} + rdr, err := c.readInfo(ctx, expectedSchema, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)") } @@ -1193,9 +1441,8 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat } } - if rdr.Err() != nil { - result = nil - err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetTables)") + if err := checkContext(rdr.Err(), ctx); err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetTables)") } return } @@ -1209,14 +1456,17 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - info, err := c.cl.GetTables(ctx, opts, c.timeouts) + var header, trailer metadata.MD + info, err := c.cl.GetTables(ctx, opts, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetTableSchema(GetTables)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(GetTables)") } - rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts) + header = metadata.MD{} + trailer = metadata.MD{} + rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(DoGet)") } defer rdr.Release() @@ -1228,7 +1478,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st Code: adbc.StatusNotFound, } } - return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(DoGet)") } numRows := rec.NumRows() @@ -1278,9 +1528,10 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st // table_type | utf8 not null func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - info, err := c.cl.GetTableTypes(ctx, c.timeouts) + var header, trailer metadata.MD + info, err := c.cl.GetTableTypes(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetTableTypes") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableTypes") } return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5) @@ -1305,14 +1556,17 @@ func (c *cnxn) Commit(ctx context.Context) error { } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - err := c.txn.Commit(ctx, c.timeouts) + var header, trailer metadata.MD + err := c.txn.Commit(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return adbcFromFlightStatus(err, "Commit") + return adbcFromFlightStatusWithDetails(err, header, trailer, "Commit") } - c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts) + header = metadata.MD{} + trailer = metadata.MD{} + c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return adbcFromFlightStatus(err, "BeginTransaction") + return adbcFromFlightStatusWithDetails(err, header, trailer, "BeginTransaction") } return nil } @@ -1336,14 +1590,17 @@ func (c *cnxn) Rollback(ctx context.Context) error { } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - err := c.txn.Rollback(ctx, c.timeouts) + var header, trailer metadata.MD + err := c.txn.Rollback(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return adbcFromFlightStatus(err, "Rollback") + return adbcFromFlightStatusWithDetails(err, header, trailer, "Rollback") } - c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts) + header = metadata.MD{} + trailer = metadata.MD{} + c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return adbcFromFlightStatus(err, "BeginTransaction") + return adbcFromFlightStatusWithDetails(err, header, trailer, "BeginTransaction") } return nil } @@ -1368,6 +1625,14 @@ func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOptio return c.cl.Execute(ctx, query, opts...) } +func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + if c.txn != nil { + return c.txn.GetExecuteSchema(ctx, query, opts...) + } + + return c.cl.GetExecuteSchema(ctx, query, opts...) +} + func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstrait(ctx, plan, opts...) @@ -1376,6 +1641,14 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPla return c.cl.ExecuteSubstrait(ctx, plan, opts...) } +func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + if c.txn != nil { + return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...) + } + + return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...) +} + func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteUpdate(ctx, query, opts...) @@ -1419,7 +1692,7 @@ func (c *cnxn) Close() error { err := c.cl.Close() c.cl = nil - return err + return adbcFromFlightStatus(err, "Close") } // ReadPartition constructs a statement for a partition of a query. The diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index d8af6a6579..d43b9fd6aa 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -43,6 +43,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) // ---- Common Infra -------------------- @@ -96,6 +97,14 @@ func TestAuthn(t *testing.T) { suite.Run(t, &AuthnTests{}) } +func TestErrorDetails(t *testing.T) { + suite.Run(t, &ErrorDetailsTests{}) +} + +func TestExecuteSchema(t *testing.T) { + suite.Run(t, &ExecuteSchemaTests{}) +} + func TestTimeout(t *testing.T) { suite.Run(t, &TimeoutTests{}) } @@ -211,6 +220,204 @@ func (suite *AuthnTests) TestBearerTokenUpdated() { defer reader.Release() } +// ---- Error Details Tests -------------------- + +type ErrorDetailsTestServer struct { + flightsql.BaseServer +} + +func (srv *ErrorDetailsTestServer) GetFlightInfoStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + if query.GetQuery() == "details" { + detail := wrapperspb.Int32Value{Value: 42} + st, err := status.New(codes.Unknown, "details").WithDetails(&detail) + if err != nil { + return nil, err + } + return nil, st.Err() + } else if query.GetQuery() == "query" { + tkt, err := flightsql.CreateStatementQueryTicket([]byte("fetch")) + if err != nil { + panic(err) + } + return &flight.FlightInfo{Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}}, nil + } + return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") +} + +func (ts *ErrorDetailsTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { + sc := arrow.NewSchema([]arrow.Field{}, nil) + detail := wrapperspb.Int32Value{Value: 42} + st, err := status.New(codes.Unknown, "details").WithDetails(&detail) + if err != nil { + return nil, nil, err + } + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: nil, + Desc: nil, + Err: st.Err(), + } + }() + return sc, ch, nil +} + +type ErrorDetailsTests struct { + ServerBasedTests +} + +func (suite *ErrorDetailsTests) SetupSuite() { + srv := ErrorDetailsTestServer{} + srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(&srv, nil, nil) +} + +func (ts *ErrorDetailsTests) TestGetFlightInfo() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("details")) + + _, _, err = stmt.ExecuteQuery(context.Background()) + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + + ts.Equal(1, len(adbcErr.Details)) + + wrapper, ok := adbcErr.Details[0].(*adbc.ProtobufErrorDetail) + ts.True(ok, "Got message: %#v", wrapper) + ts.Equal("grpc-status-details-bin", wrapper.Key()) + + message, ok := wrapper.Message.(*wrapperspb.Int32Value) + ts.True(ok, "Got message: %#v", message) + ts.Equal(int32(42), message.Value) +} + +func (ts *ErrorDetailsTests) TestDoGet() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("query")) + + reader, _, err := stmt.ExecuteQuery(context.Background()) + ts.NoError(err) + + defer reader.Release() + + for reader.Next() { + } + err = reader.Err() + + ts.Error(err) + + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr, "Error was: %#v", err) + + ts.Equal(1, len(adbcErr.Details)) + + wrapper, ok := adbcErr.Details[0].(*adbc.ProtobufErrorDetail) + ts.True(ok, "Got message: %#v", wrapper) + ts.Equal("grpc-status-details-bin", wrapper.Key()) + + message, ok := wrapper.Message.(*wrapperspb.Int32Value) + ts.True(ok, "Got message: %#v", message) + ts.Equal(int32(42), message.Value) +} + +// ---- ExecuteSchema Tests -------------------- + +type ExecuteSchemaTestServer struct { + flightsql.BaseServer +} + +func (srv *ExecuteSchemaTestServer) GetSchemaStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) { + if query.GetQuery() == "sample query" { + return &flight.SchemaResult{ + Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil), srv.Alloc), + }, nil + } + return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") +} + +func (srv *ExecuteSchemaTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (res flightsql.ActionCreatePreparedStatementResult, err error) { + if req.GetQuery() == "sample query" { + return flightsql.ActionCreatePreparedStatementResult{ + DatasetSchema: arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil), + }, nil + } + return flightsql.ActionCreatePreparedStatementResult{}, status.Error(codes.Unimplemented, "CreatePreparedStatement not implemented") +} + +type ExecuteSchemaTests struct { + ServerBasedTests +} + +func (suite *ExecuteSchemaTests) SetupSuite() { + srv := ExecuteSchemaTestServer{} + srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(&srv, nil, nil) +} + +func (ts *ExecuteSchemaTests) TestNoQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + es := stmt.(adbc.StatementExecuteSchema) + _, err = es.ExecuteSchema(context.Background()) + + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + ts.Equal(adbc.StatusInvalidState, adbcErr.Code, adbcErr.Error()) +} + +func (ts *ExecuteSchemaTests) TestPreparedQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("sample query")) + ts.NoError(stmt.Prepare(context.Background())) + + es := stmt.(adbc.StatementExecuteSchema) + schema, err := es.ExecuteSchema(context.Background()) + ts.NoError(err) + ts.NotNil(schema) + + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + ts.True(expectedSchema.Equal(schema), schema.String()) +} + +func (ts *ExecuteSchemaTests) TestQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("sample query")) + + es := stmt.(adbc.StatementExecuteSchema) + schema, err := es.ExecuteSchema(context.Background()) + ts.NoError(err) + ts.NotNil(schema) + + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + ts.True(expectedSchema.Equal(schema), schema.String()) +} + // ---- Timeout Tests -------------------- type TimeoutTestServer struct { @@ -330,6 +537,67 @@ func (ts *TimeoutTests) TestRemoveTimeout() { } } +func (ts *TimeoutTests) TestGetSet() { + keys := []string{ + "adbc.flight.sql.rpc.timeout_seconds.fetch", + "adbc.flight.sql.rpc.timeout_seconds.query", + "adbc.flight.sql.rpc.timeout_seconds.update", + } + stmt, err := ts.cnxn.NewStatement() + ts.Require().NoError(err) + defer stmt.Close() + + for _, v := range []interface{}{ts.db, ts.cnxn, stmt} { + getset := v.(adbc.GetSetOptions) + + for _, k := range keys { + strval, err := getset.GetOption(k) + ts.NoError(err) + ts.Equal("0s", strval) + + intval, err := getset.GetOptionInt(k) + ts.NoError(err) + ts.Equal(int64(0), intval) + + floatval, err := getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(0.0, floatval) + + err = getset.SetOptionInt(k, 1) + ts.NoError(err) + + strval, err = getset.GetOption(k) + ts.NoError(err) + ts.Equal("1s", strval) + + intval, err = getset.GetOptionInt(k) + ts.NoError(err) + ts.Equal(int64(1), intval) + + floatval, err = getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(1.0, floatval) + + err = getset.SetOptionDouble(k, 0.1) + ts.NoError(err) + + strval, err = getset.GetOption(k) + ts.NoError(err) + ts.Equal("100ms", strval) + + intval, err = getset.GetOptionInt(k) + ts.NoError(err) + // truncated + ts.Equal(int64(0), intval) + + floatval, err = getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(0.1, floatval) + } + } + +} + func (ts *TimeoutTests) TestDoActionTimeout() { ts.NoError(ts.cnxn.(adbc.PostInitOptions). SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1")) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go index ee7f92802e..9b434593cb 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go @@ -229,14 +229,20 @@ func (s *FlightSQLQuirks) DropTable(cnxn adbc.Connection, tblname string) error return err } -func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } -func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" } -func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } +func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" } +func (s *FlightSQLQuirks) SupportsBulkIngest(string) bool { return false } +func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) SupportsCurrentCatalogSchema() bool { return false } + +// The driver supports it, but the server we use for testing does not. +func (s *FlightSQLQuirks) SupportsExecuteSchema() bool { return false } +func (s *FlightSQLQuirks) SupportsGetSetOptions() bool { return true } func (s *FlightSQLQuirks) SupportsPartitionedData() bool { return true } +func (s *FlightSQLQuirks) SupportsStatistics() bool { return false } func (s *FlightSQLQuirks) SupportsTransactions() bool { return true } func (s *FlightSQLQuirks) SupportsGetParameterSchema() bool { return false } func (s *FlightSQLQuirks) SupportsDynamicParameterBinding() bool { return true } -func (s *FlightSQLQuirks) SupportsBulkIngest() bool { return false } func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} { switch code { case adbc.InfoDriverName: @@ -247,6 +253,8 @@ func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoDriverADBCVersion: + return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: return "db_name" case adbc.InfoVendorVersion: @@ -273,6 +281,7 @@ func (s *FlightSQLQuirks) SampleTableSchemaMetadata(tblName string, dt arrow.Dat } } +func (s *FlightSQLQuirks) Catalog() string { return "" } func (s *FlightSQLQuirks) DBSchema() string { return "" } func TestADBCFlightSQL(t *testing.T) { diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index 3e7d20e1c4..1dc7074473 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -73,6 +73,29 @@ func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.C } } +func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*arrow.Schema, error) { + var ( + res *flight.SchemaResult + err error + ) + if s.sqlQuery != "" { + res, err = cnxn.executeSchema(ctx, s.sqlQuery, opts...) + } else if s.substraitPlan != nil { + res, err = cnxn.executeSubstraitSchema(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) + } else { + return nil, adbc.Error{ + Code: adbc.StatusInvalidState, + Msg: "[Flight SQL Statement] cannot call ExecuteQuery without a query or prepared statement", + } + } + + if err != nil { + return nil, err + } + + return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc) +} + func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (int64, error) { if s.sqlQuery != "" { return cnxn.executeUpdate(ctx, s.sqlQuery, opts...) @@ -112,7 +135,9 @@ type statement struct { } func (s *statement) closePreparedStatement() error { - return s.prepared.Close(metadata.NewOutgoingContext(context.Background(), s.hdrs), s.timeouts) + var header, trailer metadata.MD + err := s.prepared.Close(metadata.NewOutgoingContext(context.Background(), s.hdrs), grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) + return adbcFromFlightStatusWithDetails(err, header, trailer, "ClosePreparedStatement") } // Close releases any relevant resources associated with this statement @@ -138,6 +163,72 @@ func (s *statement) Close() (err error) { return err } +func (s *statement) GetOption(key string) (string, error) { + switch key { + case OptionStatementSubstraitVersion: + return s.query.substraitVersion, nil + case OptionTimeoutFetch: + return s.timeouts.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return s.timeouts.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return s.timeouts.updateTimeout.String(), nil + } + + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) + values := s.hdrs.Get(name) + if len(values) > 0 { + return values[0], nil + } + } + + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := s.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return s.timeouts.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return s.timeouts.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return s.timeouts.updateTimeout.Seconds(), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} + // SetOption sets a string option on this statement func (s *statement) SetOption(key string, val string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { @@ -152,35 +243,11 @@ func (s *statement) SetOption(key string, val string) error { switch key { case OptionTimeoutFetch: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.fetchTimeout = timeout + fallthrough case OptionTimeoutQuery: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.queryTimeout = timeout + fallthrough case OptionTimeoutUpdate: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.updateTimeout = timeout + return s.timeouts.setTimeoutString(key, val) case OptionStatementQueueSize: var err error var size int @@ -189,13 +256,8 @@ func (s *statement) SetOption(key string, val string) error { Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), Code: adbc.StatusInvalidArgument, } - } else if size <= 0 { - return adbc.Error{ - Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), - Code: adbc.StatusInvalidArgument, - } } - s.queueSize = size + return s.SetOptionInt(key, int64(size)) case OptionStatementSubstraitVersion: s.query.substraitVersion = val default: @@ -207,6 +269,43 @@ func (s *statement) SetOption(key string, val string) error { return nil } +func (s *statement) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (s *statement) SetOptionInt(key string, value int64) error { + switch key { + case OptionStatementQueueSize: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value), + Code: adbc.StatusInvalidArgument, + } + } + s.queueSize = int(value) + return nil + } + return s.SetOptionDouble(key, float64(value)) +} + +func (s *statement) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return s.timeouts.setTimeout(key, value) + } + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. @@ -232,14 +331,16 @@ func (s *statement) SetSqlQuery(query string) error { func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, nrec int64, err error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) var info *flight.FlightInfo + var header, trailer metadata.MD + opts := append([]grpc.CallOption{}, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if s.prepared != nil { - info, err = s.prepared.Execute(ctx, s.timeouts) + info, err = s.prepared.Execute(ctx, opts...) } else { - info, err = s.query.execute(ctx, s.cnxn, s.timeouts) + info, err = s.query.execute(ctx, s.cnxn, opts...) } if err != nil { - return nil, -1, adbcFromFlightStatus(err, "ExecuteQuery") + return nil, -1, adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteQuery") } nrec = info.TotalRecords @@ -251,15 +352,16 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n // set. It returns the number of rows affected if known, otherwise -1. func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) - + var header, trailer metadata.MD + opts := append([]grpc.CallOption{}, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if s.prepared != nil { - n, err = s.prepared.ExecuteUpdate(ctx, s.timeouts) + n, err = s.prepared.ExecuteUpdate(ctx, opts...) } else { - n, err = s.query.executeUpdate(ctx, s.cnxn, s.timeouts) + n, err = s.query.executeUpdate(ctx, s.cnxn, opts...) } if err != nil { - err = adbcFromFlightStatus(err, "ExecuteUpdate") + err = adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteQuery") } return @@ -269,9 +371,10 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { // multiple times. This invalidates any prior result sets. func (s *statement) Prepare(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) - prep, err := s.query.prepare(ctx, s.cnxn, s.timeouts) + var header, trailer metadata.MD + prep, err := s.query.prepare(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if err != nil { - return adbcFromFlightStatus(err, "Prepare") + return adbcFromFlightStatusWithDetails(err, header, trailer, "Prepare") } s.prepared = prep return nil @@ -387,14 +490,15 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. err error ) + var header, trailer metadata.MD if s.prepared != nil { - info, err = s.prepared.Execute(ctx, s.timeouts) + info, err = s.prepared.Execute(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) } else { - info, err = s.query.execute(ctx, s.cnxn, s.timeouts) + info, err = s.query.execute(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) } if err != nil { - return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions") + return nil, out, -1, adbcFromFlightStatusWithDetails(err, header, trailer, "ExecutePartitions") } if len(info.Schema) > 0 { @@ -422,3 +526,26 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. return sc, out, info.TotalRecords, nil } + +// ExecuteSchema gets the schema of the result set of a query without executing it. +func (s *statement) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) { + ctx = metadata.NewOutgoingContext(ctx, s.hdrs) + + if s.prepared != nil { + schema = s.prepared.DatasetSchema() + if schema == nil { + err = adbc.Error{ + Msg: "[Flight SQL Statement] Database server did not provide schema for prepared statement", + Code: adbc.StatusNotImplemented, + } + } + return + } + + var header, trailer metadata.MD + schema, err = s.query.executeSchema(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) + if err != nil { + err = adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteSchema") + } + return +} diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go index c2721a7af3..e505895c42 100644 --- a/go/adbc/driver/flightsql/record_reader.go +++ b/go/adbc/driver/flightsql/record_reader.go @@ -32,6 +32,7 @@ import ( "github.com/bluele/gcache" "golang.org/x/sync/errgroup" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) type reader struct { @@ -49,6 +50,8 @@ type reader struct { // gathers all of the records as they come in. func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.Client, info *flight.FlightInfo, clCache gcache.Cache, bufferSize int, opts ...grpc.CallOption) (rdr array.RecordReader, err error) { endpoints := info.Endpoint + var header, trailer metadata.MD + opts = append(append([]grpc.CallOption{}, opts...), grpc.Header(&header), grpc.Trailer(&trailer)) var schema *arrow.Schema if len(endpoints) == 0 { if info.Schema == nil { @@ -88,9 +91,10 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. Code: adbc.StatusInvalidState} } } else { - rdr, err := doGet(ctx, cl, endpoints[0], clCache, opts...) + firstEndpoint := endpoints[0] + rdr, err := doGet(ctx, cl, firstEndpoint, clCache, opts...) if err != nil { - return nil, adbcFromFlightStatus(err, "DoGet: endpoint 0: remote: %s", endpoints[0].Location) + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "DoGet: endpoint 0: remote: %s", firstEndpoint.Location) } schema = rdr.Schema() group.Go(func() error { @@ -104,7 +108,10 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rec.Retain() ch <- rec } - return rdr.Err() + if err := checkContext(rdr.Err(), ctx); err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "DoGet: endpoint 0: remote: %s", firstEndpoint.Location) + } + return nil }) endpoints = endpoints[1:] @@ -135,7 +142,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rdr, err := doGet(ctx, cl, endpoint, clCache, opts...) if err != nil { - return adbcFromFlightStatus(err, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location) + return adbcFromFlightStatusWithDetails(err, header, trailer, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location) } defer rdr.Release() @@ -150,7 +157,10 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. chs[endpointIndex] <- rec } - return rdr.Err() + if err := checkContext(rdr.Err(), ctx); err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location) + } + return nil }) } diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go index e4cf276807..d0d1af850c 100644 --- a/go/adbc/driver/flightsql/utils.go +++ b/go/adbc/driver/flightsql/utils.go @@ -18,14 +18,22 @@ package flightsql import ( + "context" "fmt" "github.com/apache/arrow-adbc/go/adbc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" ) func adbcFromFlightStatus(err error, context string, args ...any) error { + var header, trailer metadata.MD + return adbcFromFlightStatusWithDetails(err, header, trailer, context, args...) +} + +func adbcFromFlightStatusWithDetails(err error, header, trailer metadata.MD, context string, args ...any) error { if _, ok := err.(adbc.Error); ok { return err } @@ -72,9 +80,58 @@ func adbcFromFlightStatus(err error, context string, args ...any) error { adbcCode = adbc.StatusUnknown } - // People don't read error messages, so backload the context and frontload the server error + details := []adbc.ErrorDetail{} + // slice of proto.Message or error + for _, detail := range grpcStatus.Details() { + if err, ok := detail.(error); ok { + details = append(details, &adbc.TextErrorDetail{Name: "grpc-status-details-bin", Detail: err.Error()}) + } else if msg, ok := detail.(proto.Message); ok { + details = append(details, &adbc.ProtobufErrorDetail{Name: "grpc-status-details-bin", Message: msg}) + } + // else, gRPC returned non-Protobuf detail in violation of their method contract + } + + // XXX(https://github.com/grpc/grpc-go/issues/5485): don't count on + // grpc-status-details-bin since Google hardcodes it to only work with + // Google Cloud + // XXX: must check both headers and trailers because some implementations + // (like gRPC-Java) will consolidate trailers into headers for failed RPCs + for key, values := range header { + switch key { + case "content-type", "grpc-status-details-bin": + continue + default: + for _, value := range values { + details = append(details, &adbc.TextErrorDetail{Name: key, Detail: value}) + } + } + } + for key, values := range trailer { + switch key { + case "content-type", "grpc-status-details-bin": + continue + default: + for _, value := range values { + details = append(details, &adbc.TextErrorDetail{Name: key, Detail: value}) + } + } + } + return adbc.Error{ - Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)), - Code: adbcCode, + // People don't read error messages, so backload the context and frontload the server error + Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)), + Code: adbcCode, + Details: details, + } +} + +func checkContext(maybeErr error, ctx context.Context) error { + if maybeErr != nil { + return maybeErr + } else if ctx.Err() == context.Canceled { + return adbc.Error{Msg: "Cancelled by request", Code: adbc.StatusCancelled} + } else if ctx.Err() == context.DeadlineExceeded { + return adbc.Error{Msg: "Deadline exceeded", Code: adbc.StatusTimeout} } + return ctx.Err() } diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index b08b94a7bc..628ac85f86 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -22,6 +22,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "io" "strconv" "strings" "time" @@ -100,6 +101,7 @@ type cnxn struct { // codes (the row will be omitted from the result). func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { const strValTypeID arrow.UnionTypeCode = 0 + const intValTypeID arrow.UnionTypeCode = 2 if len(infoCodes) == 0 { infoCodes = infoSupportedCodes @@ -111,7 +113,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr := bldr.Field(0).(*array.Uint32Builder) infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder) + strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) for _, code := range infoCodes { switch code { @@ -127,6 +130,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) strInfoBldr.Append(infoDriverArrowVersion) + case adbc.InfoDriverADBCVersion: + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(intValTypeID) + intInfoBldr.Append(adbc.AdbcVersion1_1_0) case adbc.InfoVendorName: infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) @@ -679,6 +686,85 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString) (fie return } +func (c *cnxn) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyAutoCommit: + if c.activeTransaction { + // No autocommit + return adbc.OptionValueDisabled, nil + } else { + // Autocommit + return adbc.OptionValueEnabled, nil + } + case adbc.OptionKeyCurrentCatalog: + return c.getStringQuery("SELECT CURRENT_DATABASE()") + case adbc.OptionKeyCurrentDbSchema: + return c.getStringQuery("SELECT CURRENT_SCHEMA()") + } + + return "", adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) getStringQuery(query string) (string, error) { + result, err := c.cn.QueryContext(context.Background(), query, nil) + if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + defer result.Close() + + if len(result.Columns()) != 1 { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong number of columns: %s", result.Columns()), + Code: adbc.StatusInternal, + } + } + + dest := make([]driver.Value, 1) + err = result.Next(dest) + if err == io.EOF { + return "", adbc.Error{ + Msg: "[Snowflake] Internal query returned no rows", + Code: adbc.StatusInternal, + } + } else if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + + value, ok := dest[0].(string) + if !ok { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong type of value: %s", dest[0]), + Code: adbc.StatusInternal, + } + } + + return value, nil +} + +func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionInt(key string) (int64, error) { + return 0, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionDouble(key string) (float64, error) { + return 0.0, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { tblParts := make([]string, 0, 3) if catalog != nil { @@ -847,6 +933,12 @@ func (c *cnxn) SetOption(key, value string) error { Code: adbc.StatusInvalidArgument, } } + case adbc.OptionKeyCurrentCatalog: + _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}}) + return err + case adbc.OptionKeyCurrentDbSchema: + _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}}) + return err default: return adbc.Error{ Msg: "[Snowflake] unknown connection option " + key + ": " + value, @@ -854,3 +946,24 @@ func (c *cnxn) SetOption(key, value string) error { } } } + +func (c *cnxn) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionInt(key string, value int64) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index c02b58ddec..a00513817b 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -209,6 +209,105 @@ type database struct { alloc memory.Allocator } +func (d *database) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyUsername: + return d.cfg.User, nil + case adbc.OptionKeyPassword: + return d.cfg.Password, nil + case OptionDatabase: + return d.cfg.Database, nil + case OptionSchema: + return d.cfg.Schema, nil + case OptionWarehouse: + return d.cfg.Warehouse, nil + case OptionRole: + return d.cfg.Role, nil + case OptionRegion: + return d.cfg.Region, nil + case OptionAccount: + return d.cfg.Account, nil + case OptionProtocol: + return d.cfg.Protocol, nil + case OptionHost: + return d.cfg.Host, nil + case OptionPort: + return strconv.Itoa(d.cfg.Port), nil + case OptionAuthType: + return d.cfg.Authenticator.String(), nil + case OptionLoginTimeout: + return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f', -1, 64), nil + case OptionRequestTimeout: + return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f', -1, 64), nil + case OptionJwtExpireTimeout: + return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(), 'f', -1, 64), nil + case OptionClientTimeout: + return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f', -1, 64), nil + case OptionApplicationName: + return d.cfg.Application, nil + case OptionSSLSkipVerify: + if d.cfg.InsecureMode { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionOCSPFailOpenMode: + return strconv.FormatUint(uint64(d.cfg.OCSPFailOpen), 10), nil + case OptionAuthToken: + return d.cfg.Token, nil + case OptionAuthOktaUrl: + return d.cfg.OktaURL.String(), nil + case OptionKeepSessionAlive: + if d.cfg.KeepSessionAlive { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionDisableTelemetry: + if d.cfg.DisableTelemetry { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionClientRequestMFAToken: + if d.cfg.ClientRequestMfaToken == gosnowflake.ConfigBoolTrue { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionClientStoreTempCred: + if d.cfg.ClientStoreTemporaryCredential == gosnowflake.ConfigBoolTrue { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionLogTracing: + return d.cfg.Tracing, nil + default: + val, ok := d.cfg.Params[key] + if ok { + return *val, nil + } + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionInt(key string) (int64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionDouble(key string) (float64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} + func (d *database) SetOptions(cnOptions map[string]string) error { uri, ok := cnOptions[adbc.OptionKeyURI] if ok { @@ -421,6 +520,35 @@ func (d *database) SetOptions(cnOptions map[string]string) error { return nil } +func (d *database) SetOption(key string, val string) error { + // Can't set options after init + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionInt(key string, value int64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + func (d *database) Open(ctx context.Context) (adbc.Connection, error) { connector := gosnowflake.NewConnector(drv, *d.cfg) diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 011694a280..712f730e6c 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -38,10 +38,11 @@ import ( ) type SnowflakeQuirks struct { - dsn string - mem *memory.CheckedAllocator - connector gosnowflake.Connector - schemaName string + dsn string + mem *memory.CheckedAllocator + connector gosnowflake.Connector + catalogName string + schemaName string } func (s *SnowflakeQuirks) SetupDriver(t *testing.T) adbc.Driver { @@ -180,12 +181,17 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error func (s *SnowflakeQuirks) Alloc() memory.Allocator { return s.mem } func (s *SnowflakeQuirks) BindParameter(_ int) string { return "?" } +func (s *SnowflakeQuirks) SupportsBulkIngest(string) bool { return true } func (s *SnowflakeQuirks) SupportsConcurrentStatements() bool { return true } +func (s *SnowflakeQuirks) SupportsCurrentCatalogSchema() bool { return true } +func (s *SnowflakeQuirks) SupportsExecuteSchema() bool { return false } +func (s *SnowflakeQuirks) SupportsGetSetOptions() bool { return true } func (s *SnowflakeQuirks) SupportsPartitionedData() bool { return false } +func (s *SnowflakeQuirks) SupportsStatistics() bool { return false } func (s *SnowflakeQuirks) SupportsTransactions() bool { return true } func (s *SnowflakeQuirks) SupportsGetParameterSchema() bool { return false } func (s *SnowflakeQuirks) SupportsDynamicParameterBinding() bool { return false } -func (s *SnowflakeQuirks) SupportsBulkIngest() bool { return true } +func (s *SnowflakeQuirks) Catalog() string { return s.catalogName } func (s *SnowflakeQuirks) DBSchema() string { return s.schemaName } func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { switch code { @@ -197,6 +203,8 @@ func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoDriverADBCVersion: + return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: return "Snowflake" } @@ -225,7 +233,7 @@ func createTempSchema(uri string) string { } defer db.Close() - schemaName := "ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_") + schemaName := strings.ToUpper("ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_")) _, err = db.Exec(`CREATE SCHEMA ADBC_TESTING.` + schemaName) if err != nil { panic(err) @@ -249,14 +257,17 @@ func dropTempSchema(uri, schema string) { func withQuirks(t *testing.T, fn func(*SnowflakeQuirks)) { uri := os.Getenv("SNOWFLAKE_URI") + database := os.Getenv("SNOWFLAKE_DATABASE") if uri == "" { t.Skip("no SNOWFLAKE_URI defined, skip snowflake driver tests") + } else if database == "" { + t.Skip("no SNOWFLAKE_DATABASE defined, skip snowflake driver tests") } // avoid multiple runs clashing by operating in a fresh schema and then // dropping that schema when we're done. - q := &SnowflakeQuirks{dsn: uri, schemaName: createTempSchema(uri)} + q := &SnowflakeQuirks{dsn: uri, catalogName: database, schemaName: createTempSchema(uri)} defer dropTempSchema(uri, q.schemaName) fn(q) diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index e0d14582d9..e8707bfdb4 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -45,7 +45,7 @@ type statement struct { query string targetTable string - append bool + ingestMode string bound arrow.Record streamBind array.RecordReader @@ -73,6 +73,35 @@ func (st *statement) Close() error { return nil } +func (st *statement) GetOption(key string) (string, error) { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionInt(key string) (int64, error) { + switch key { + case OptionStatementQueueSize: + return int64(st.queueSize), nil + } + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionDouble(key string) (float64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} + // SetOption sets a string option on this statement func (st *statement) SetOption(key string, val string) error { switch key { @@ -82,9 +111,13 @@ func (st *statement) SetOption(key string, val string) error { case adbc.OptionKeyIngestMode: switch val { case adbc.OptionValueIngestModeAppend: - st.append = true + fallthrough case adbc.OptionValueIngestModeCreate: - st.append = false + fallthrough + case adbc.OptionValueIngestModeReplace: + fallthrough + case adbc.OptionValueIngestModeCreateAppend: + st.ingestMode = val default: return adbc.Error{ Msg: fmt.Sprintf("invalid statement option %s=%s", key, val), @@ -99,13 +132,7 @@ func (st *statement) SetOption(key string, val string) error { Code: adbc.StatusInvalidArgument, } } - if sz <= 0 { - return adbc.Error{ - Msg: fmt.Sprintf("invalid value ('%d') for option '%s', must be > 0", sz, key), - Code: adbc.StatusInvalidArgument, - } - } - st.queueSize = sz + return st.SetOptionInt(key, int64(sz)) case OptionStatementPrefetchConcurrency: concurrency, err := strconv.Atoi(val) if err != nil { @@ -114,13 +141,7 @@ func (st *statement) SetOption(key string, val string) error { Code: adbc.StatusInvalidArgument, } } - if concurrency <= 0 { - return adbc.Error{ - Msg: fmt.Sprintf("invalid value ('%d') for option '%s', must be > 0", concurrency, key), - Code: adbc.StatusInvalidArgument, - } - } - st.prefetchConcurrency = concurrency + return st.SetOptionInt(key, int64(concurrency)) default: return adbc.Error{ Msg: fmt.Sprintf("invalid statement option %s=%s", key, val), @@ -130,6 +151,47 @@ func (st *statement) SetOption(key string, val string) error { return nil } +func (st *statement) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (st *statement) SetOptionInt(key string, value int64) error { + switch key { + case OptionStatementQueueSize: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value), + Code: adbc.StatusInvalidArgument, + } + } + st.queueSize = int(value) + return nil + case OptionStatementPrefetchConcurrency: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("invalid value ('%d') for option '%s', must be > 0", value, key), + Code: adbc.StatusInvalidArgument, + } + } + st.prefetchConcurrency = int(value) + return nil + } + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (st *statement) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. @@ -196,6 +258,9 @@ func (st *statement) initIngest(ctx context.Context) (string, error) { ) createBldr.WriteString("CREATE TABLE ") + if st.ingestMode == adbc.OptionValueIngestModeCreateAppend { + createBldr.WriteString(" IF NOT EXISTS ") + } createBldr.WriteString(st.targetTable) createBldr.WriteString(" (") @@ -237,7 +302,22 @@ func (st *statement) initIngest(ctx context.Context) (string, error) { createBldr.WriteString(")") insertBldr.WriteString(")") - if !st.append { + switch st.ingestMode { + case adbc.OptionValueIngestModeAppend: + // Do nothing + case adbc.OptionValueIngestModeReplace: + replaceQuery := "DROP TABLE IF EXISTS " + st.targetTable + _, err := st.cnxn.cn.ExecContext(ctx, replaceQuery, nil) + if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + + fallthrough + case adbc.OptionValueIngestModeCreate: + fallthrough + case adbc.OptionValueIngestModeCreateAppend: + fallthrough + default: // create the table! createQuery := createBldr.String() _, err := st.cnxn.cn.ExecContext(ctx, createQuery, nil) diff --git a/go/adbc/drivermgr/adbc.h b/go/adbc/drivermgr/adbc.h index 154e881255..1ec2f05080 100644 --- a/go/adbc/drivermgr/adbc.h +++ b/go/adbc/drivermgr/adbc.h @@ -35,7 +35,7 @@ /// but not concurrent access. Specific implementations may permit /// multiple threads. /// -/// \version 1.0.0 +/// \version 1.1.0 #pragma once @@ -248,7 +248,24 @@ typedef uint8_t AdbcStatusCode; /// May indicate a database-side error only. #define ADBC_STATUS_UNAUTHORIZED 14 +/// \brief Inform the driver/driver manager that we are using the extended +/// AdbcError struct from ADBC 1.1.0. +/// +/// See the AdbcError documentation for usage. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA INT32_MIN + /// \brief A detailed error message for an operation. +/// +/// The caller must zero-initialize this struct (clarified in ADBC 1.1.0). +/// +/// The structure was extended in ADBC 1.1.0. Drivers and clients using ADBC +/// 1.0.0 will not have the private_data or private_driver fields. Drivers +/// should read/write these fields if and only if vendor_code is equal to +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. Clients are required to initialize +/// this struct to avoid the possibility of uninitialized values confusing the +/// driver. struct ADBC_EXPORT AdbcError { /// \brief The error message. char* message; @@ -266,8 +283,112 @@ struct ADBC_EXPORT AdbcError { /// Unlike other structures, this is an embedded callback to make it /// easier for the driver manager and driver to cooperate. void (*release)(struct AdbcError* error); + + /// \brief Opaque implementation-defined state. + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. If present, this field is NULLPTR + /// iff the error is unintialized/freed. + /// + /// \since ADBC API revision 1.1.0 + void* private_data; + + /// \brief The associated driver (used by the driver manager to help + /// track state). + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. + /// + /// \since ADBC API revision 1.1.0 + struct AdbcDriver* private_driver; }; +#ifdef __cplusplus +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + (AdbcError{nullptr, \ + ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, \ + {0, 0, 0, 0, 0}, \ + nullptr, \ + nullptr, \ + nullptr}) +#else +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + ((struct AdbcError){ \ + NULL, ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, {0, 0, 0, 0, 0}, NULL, NULL, NULL}) +#endif + +/// \brief The size of the AdbcError structure in ADBC 1.0.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcDriver struct when vendor_code is not +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_0_0_SIZE (offsetof(struct AdbcError, private_data)) +/// \brief The size of the AdbcError structure in ADBC 1.1.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcDriver struct when vendor_code is +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_1_0_SIZE (sizeof(struct AdbcError)) + +/// \brief Extra key-value metadata for an error. +/// +/// The fields here are owned by the driver and should not be freed. The +/// fields here are invalidated when the release callback in AdbcError is +/// called. +/// +/// \since ADBC API revision 1.1.0 +struct ADBC_EXPORT AdbcErrorDetail { + /// \brief The metadata key. + const char* key; + /// \brief The binary metadata value. + const uint8_t* value; + /// \brief The length of the metadata value. + size_t value_length; +}; + +/// \brief Get the number of metadata values available in an error. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +int AdbcErrorGetDetailCount(const struct AdbcError* error); + +/// \brief Get a metadata value in an error by index. +/// +/// If index is invalid, returns an AdbcErrorDetail initialized with NULL/0 +/// fields. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index); + +/// \brief Get an ADBC error from an ArrowArrayStream created by a driver. +/// +/// This allows retrieving error details and other metadata that would +/// normally be suppressed by the Arrow C Stream Interface. +/// +/// The caller MUST NOT release the error; it is managed by the release +/// callback in the stream itself. +/// +/// \param[in] stream The stream to query. +/// \param[out] status The ADBC status code, or ADBC_STATUS_OK if there is no +/// error. Not written to if the stream does not contain an ADBC error or +/// if the pointer is NULL. +/// \return NULL if not supported. +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + /// @} /// \defgroup adbc-constants Constants @@ -279,6 +400,14 @@ struct ADBC_EXPORT AdbcError { /// point to an AdbcDriver. #define ADBC_VERSION_1_0_0 1000000 +/// \brief ADBC revision 1.1.0. +/// +/// When passed to an AdbcDriverInitFunc(), the driver parameter must +/// point to an AdbcDriver. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_VERSION_1_1_0 1001000 + /// \brief Canonical option value for enabling an option. /// /// For use as the value in SetOption calls. @@ -288,6 +417,34 @@ struct ADBC_EXPORT AdbcError { /// For use as the value in SetOption calls. #define ADBC_OPTION_VALUE_DISABLED "false" +/// \brief Canonical option name for URIs. +/// +/// Should be used as the expected option name to specify a URI for +/// any ADBC driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_URI "uri" +/// \brief Canonical option name for usernames. +/// +/// Should be used as the expected option name to specify a username +/// to a driver for authentication. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_USERNAME "username" +/// \brief Canonical option name for passwords. +/// +/// Should be used as the expected option name to specify a password +/// for authentication to a driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_PASSWORD "password" + /// \brief The database vendor/product name (e.g. the server name). /// (type: utf8). /// @@ -315,6 +472,15 @@ struct ADBC_EXPORT AdbcError { /// /// \see AdbcConnectionGetInfo #define ADBC_INFO_DRIVER_ARROW_VERSION 102 +/// \brief The driver ADBC API version (type: int64). +/// +/// The value should be one of the ADBC_VERSION constants. +/// +/// \since ADBC API revision 1.1.0 +/// \see AdbcConnectionGetInfo +/// \see ADBC_VERSION_1_0_0 +/// \see ADBC_VERSION_1_1_0 +#define ADBC_INFO_DRIVER_ADBC_VERSION 103 /// \brief Return metadata on catalogs, schemas, tables, and columns. /// @@ -337,18 +503,133 @@ struct ADBC_EXPORT AdbcError { /// \see AdbcConnectionGetObjects #define ADBC_OBJECT_DEPTH_COLUMNS ADBC_OBJECT_DEPTH_ALL +/// \defgroup adbc-table-statistics ADBC Statistic Types +/// Standard statistic names for AdbcConnectionGetStatistics. +/// @{ + +/// \brief The dictionary-encoded name of the average byte width statistic. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY 0 +/// \brief The average byte width statistic. The average size in bytes of a +/// row in the column. Value type is float64. +/// +/// For example, this is roughly the average length of a string for a string +/// column. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the distinct value count statistic. +#define ADBC_STATISTIC_DISTINCT_COUNT_KEY 1 +/// \brief The distinct value count (NDV) statistic. The number of distinct +/// values in the column. Value type is int64 (when not approximate) or +/// float64 (when approximate). +#define ADBC_STATISTIC_DISTINCT_COUNT_NAME "adbc.statistic.distinct_count" +/// \brief The dictionary-encoded name of the max byte width statistic. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_KEY 2 +/// \brief The max byte width statistic. The maximum size in bytes of a row +/// in the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +/// +/// For example, this is the maximum length of a string for a string column. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the max value statistic. +#define ADBC_STATISTIC_MAX_VALUE_KEY 3 +/// \brief The max value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MAX_VALUE_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the min value statistic. +#define ADBC_STATISTIC_MIN_VALUE_KEY 4 +/// \brief The min value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MIN_VALUE_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the null count statistic. +#define ADBC_STATISTIC_NULL_COUNT_KEY 5 +/// \brief The null count statistic. The number of values that are null in +/// the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +#define ADBC_STATISTIC_NULL_COUNT_NAME "adbc.statistic.null_count" +/// \brief The dictionary-encoded name of the row count statistic. +#define ADBC_STATISTIC_ROW_COUNT_KEY 6 +/// \brief The row count statistic. The number of rows in the column or +/// table. Value type is int64 (when not approximate) or float64 (when +/// approximate). +#define ADBC_STATISTIC_ROW_COUNT_NAME "adbc.statistic.row_count" +/// @} + /// \brief The name of the canonical option for whether autocommit is /// enabled. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_AUTOCOMMIT "adbc.connection.autocommit" /// \brief The name of the canonical option for whether the current /// connection should be restricted to being read-only. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_READ_ONLY "adbc.connection.readonly" +/// \brief The name of the canonical option for the current catalog. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_CATALOG "adbc.connection.catalog" + +/// \brief The name of the canonical option for the current schema. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA "adbc.connection.db_schema" + +/// \brief The name of the canonical option for making query execution +/// nonblocking. +/// +/// When enabled, AdbcStatementExecutePartitions will return +/// partitions as soon as they are available, instead of returning +/// them all at the end. When there are no more to return, it will +/// return an empty set of partitions. AdbcStatementExecuteQuery and +/// AdbcStatementExecuteSchema are not affected. +/// +/// The default is ADBC_OPTION_VALUE_DISABLED. +/// +/// The type is char*. +/// +/// \see AdbcStatementSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_INCREMENTAL "adbc.statement.exec.incremental" + +/// \brief The name of the option for getting the progress of a query. +/// +/// The value is not necessarily in any particular range or have any +/// particular units. (For example, it might be a percentage, bytes of data, +/// rows of data, number of workers, etc.) The max value can be retrieved via +/// ADBC_STATEMENT_OPTION_MAX_PROGRESS. This represents the progress of +/// execution, not of consumption (i.e., it is independent of how much of the +/// result set has been read by the client via ArrowArrayStream.get_next().) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_PROGRESS "adbc.statement.exec.progress" + +/// \brief The name of the option for getting the maximum progress of a query. +/// +/// This is the value of ADBC_STATEMENT_OPTION_PROGRESS for a completed query. +/// If not supported, or if the value is nonpositive, then the maximum is not +/// known. (For instance, the query may be fully streaming and the driver +/// does not know when the result set will end.) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_MAX_PROGRESS "adbc.statement.exec.max_progress" + /// \brief The name of the canonical option for setting the isolation /// level of a transaction. /// @@ -357,6 +638,8 @@ struct ADBC_EXPORT AdbcError { /// isolation level is not supported by a driver, it should return an /// appropriate error. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_ISOLATION_LEVEL \ "adbc.connection.transaction.isolation_level" @@ -449,8 +732,12 @@ struct ADBC_EXPORT AdbcError { /// exist. If the table exists but has a different schema, /// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be /// appended to the target table. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table" /// \brief Whether to create (the default) or append. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_MODE "adbc.ingest.mode" /// \brief Create the table and insert data; error if the table exists. #define ADBC_INGEST_OPTION_MODE_CREATE "adbc.ingest.mode.create" @@ -458,6 +745,15 @@ struct ADBC_EXPORT AdbcError { /// table does not exist (ADBC_STATUS_NOT_FOUND) or does not match /// the schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). #define ADBC_INGEST_OPTION_MODE_APPEND "adbc.ingest.mode.append" +/// \brief Create the table and insert data; drop the original table +/// if it already exists. +/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_REPLACE "adbc.ingest.mode.replace" +/// \brief Insert data; create the table if it does not exist, or +/// error if the table exists, but the schema does not match the +/// schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). +/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_CREATE_APPEND "adbc.ingest.mode.create_append" /// @} @@ -624,7 +920,7 @@ struct ADBC_EXPORT AdbcDriver { AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); - AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, const uint32_t*, size_t, struct ArrowArrayStream*, struct AdbcError*); AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, const char*, const char*, const char**, @@ -667,8 +963,108 @@ struct ADBC_EXPORT AdbcDriver { struct AdbcError*); AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, size_t, struct AdbcError*); + + /// \defgroup adbc-1.1.0 ADBC API Revision 1.1.0 + /// + /// Functions added in ADBC 1.1.0. For backwards compatibility, + /// these members must not be accessed unless the version passed to + /// the AdbcDriverInitFunc is greater than or equal to + /// ADBC_VERSION_1_1_0. + /// + /// For a 1.0.0 driver being loaded by a 1.1.0 driver manager: the + /// 1.1.0 manager will allocate the new, expanded AdbcDriver struct + /// and attempt to have the driver initialize it with + /// ADBC_VERSION_1_1_0. This must return an error, after which the + /// driver will try again with ADBC_VERSION_1_0_0. The driver must + /// not access the new fields, which will carry undefined values. + /// + /// For a 1.1.0 driver being loaded by a 1.0.0 driver manager: the + /// 1.0.0 manager will allocate the old AdbcDriver struct and + /// attempt to have the driver initialize it with + /// ADBC_VERSION_1_0_0. The driver must not access the new fields, + /// and should initialize the old fields. + /// + /// @{ + + int (*ErrorGetDetailCount)(const struct AdbcError* error); + struct AdbcErrorDetail (*ErrorGetDetail)(const struct AdbcError* error, int index); + const struct AdbcError* (*ErrorFromArrayStream)(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + + AdbcStatusCode (*DatabaseGetOption)(struct AdbcDatabase*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionBytes)(struct AdbcDatabase*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionDouble)(struct AdbcDatabase*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionInt)(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionBytes)(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionDouble)(struct AdbcDatabase*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionInt)(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*ConnectionCancel)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOption)(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionBytes)(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionDouble)(struct AdbcConnection*, const char*, + double*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionInt)(struct AdbcConnection*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatistics)(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatisticNames)(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionBytes)(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionDouble)(struct AdbcConnection*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionInt)(struct AdbcConnection*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*StatementCancel)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementExecuteSchema)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOption)(struct AdbcStatement*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionBytes)(struct AdbcStatement*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*StatementGetOptionDouble)(struct AdbcStatement*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionInt)(struct AdbcStatement*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionBytes)(struct AdbcStatement*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*StatementSetOptionDouble)(struct AdbcStatement*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionInt)(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); + + /// @} }; +/// \brief The size of the AdbcDriver structure in ADBC 1.0.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_0_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_0_0_SIZE (offsetof(struct AdbcDriver, ErrorGetDetailCount)) + +/// \brief The size of the AdbcDriver structure in ADBC 1.1.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_1_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_1_0_SIZE (sizeof(struct AdbcDriver)) + /// @} /// \addtogroup adbc-database @@ -684,16 +1080,189 @@ struct ADBC_EXPORT AdbcDriver { ADBC_EXPORT AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error); +/// \brief Get a string option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a double option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the double +/// representation of an integer option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error); + +/// \brief Get an integer option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the integer +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error); + /// \brief Set a char* option. /// /// Options may be set before AdbcDatabaseInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error); + +/// \brief Set a double option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error); + +/// \brief Set an integer option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error); + /// \brief Finish setting options and initialize the database. /// /// Some drivers may support setting options after initialization @@ -730,11 +1299,65 @@ AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, /// Options may be set before AdbcConnectionInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a connection. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error); + +/// \brief Set a double option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error); + /// \brief Finish setting options and initialize the connection. /// /// Some drivers may support setting options after initialization @@ -752,6 +1375,30 @@ ADBC_EXPORT AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, struct AdbcError* error); +/// \brief Cancel the in-progress operation on a connection. +/// +/// This can be called during AdbcConnectionGetObjects (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] connection The connection to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no operation to cancel. +/// \return ADBC_STATUS_UNKNOWN if the operation could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error); + /// \defgroup adbc-connection-metadata Metadata /// Functions for retrieving metadata about the database. /// @@ -765,6 +1412,8 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// concurrent active statements and it must execute a SQL query /// internally in order to implement the metadata function). /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// Some functions accept "search pattern" arguments, which are /// strings that can contain the special character "%" to match zero /// or more characters, or "_" to match exactly one character. (See @@ -799,6 +1448,10 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// for ADBC usage. Drivers/vendors will ignore requests for /// unrecognized codes (the row will be omitted from the result). /// +/// Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" +/// information, which is the same metadata provided by the same info +/// code range in the Arrow Flight SQL GetSqlInfo RPC. +/// /// \param[in] connection The connection to query. /// \param[in] info_codes A list of metadata codes to fetch, or NULL /// to fetch all. @@ -808,7 +1461,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// \param[out] error Error details, if an error occurs. ADBC_EXPORT AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error); @@ -891,6 +1544,8 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, /// | fk_table | utf8 not null | /// | fk_column_name | utf8 not null | /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[in] depth The level of nesting to display. If 0, display /// all levels. If 1, display only catalogs (i.e. catalog_schemas @@ -922,6 +1577,212 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct ArrowArrayStream* out, struct AdbcError* error); +/// \brief Get a string option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error); + +/// \brief Get a double option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error); + +/// \brief Get statistics about the data distribution of table(s). +/// +/// The result is an Arrow dataset with the following schema: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | catalog_name | utf8 | +/// | catalog_db_schemas | list not null | +/// +/// DB_SCHEMA_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | db_schema_name | utf8 | +/// | db_schema_statistics | list not null | +/// +/// STATISTICS_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | Comments | +/// |--------------------------|----------------------------------| -------- | +/// | table_name | utf8 not null | | +/// | column_name | utf8 | (1) | +/// | statistic_key | int16 not null | (2) | +/// | statistic_value | VALUE_SCHEMA not null | | +/// | statistic_is_approximate | bool not null | (3) | +/// +/// 1. If null, then the statistic applies to the entire table. +/// 2. A dictionary-encoded statistic name (although we do not use the Arrow +/// dictionary type). Values in [0, 1024) are reserved for ADBC. Other +/// values are for implementation-specific statistics. For the definitions +/// of predefined statistic types, see \ref adbc-table-statistics. To get +/// driver-specific statistic names, use AdbcConnectionGetStatisticNames. +/// 3. If true, then the value is approximate or best-effort. +/// +/// VALUE_SCHEMA is a dense union with members: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | int64 | int64 | +/// | uint64 | uint64 | +/// | float64 | float64 | +/// | binary | binary | +/// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] catalog The catalog (or nullptr). May be a search +/// pattern (see section documentation). +/// \param[in] db_schema The database schema (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] table_name The table name (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] approximate If zero, request exact values of +/// statistics, else allow for best-effort, approximate, or cached +/// values. The database may return approximate values regardless, +/// as indicated in the result. Requesting exact values may be +/// expensive or unsupported. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// \brief Get the names of statistics specific to this driver. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ---------------|---------------- +/// statistic_name | utf8 not null +/// statistic_key | int16 not null +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error); + /// \brief Get the Arrow schema of a table. /// /// \param[in] connection The database connection. @@ -945,6 +1806,8 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, /// ---------------|-------------- /// table_type | utf8 not null /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[out] out The result set. /// \param[out] error Error details, if an error occurs. @@ -973,6 +1836,8 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, /// /// A partition can be retrieved from AdbcPartitions. /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The connection to use. This does not have /// to be the same connection that the partition was created on. /// \param[in] serialized_partition The partition descriptor. @@ -1042,7 +1907,11 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, /// \brief Execute a statement and get the results. /// -/// This invalidates any prior result sets. +/// This invalidates any prior result sets. This AdbcStatement must +/// outlive the returned ArrowArrayStream. +/// +/// Since ADBC 1.1.0: releasing the returned ArrowArrayStream without +/// consuming it fully is equivalent to calling AdbcStatementCancel. /// /// \param[in] statement The statement to execute. /// \param[out] out The results. Pass NULL if the client does not @@ -1056,6 +1925,27 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, struct ArrowArrayStream* out, int64_t* rows_affected, struct AdbcError* error); +/// \brief Get the schema of the result set of a query without +/// executing it. +/// +/// This invalidates any prior result sets. +/// +/// Depending on the driver, this may require first executing +/// AdbcStatementPrepare. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to execute. +/// \param[out] out The result schema. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support this. +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error); + /// \brief Turn this statement into a prepared statement to be /// executed multiple times. /// @@ -1138,6 +2028,158 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, struct ArrowArrayStream* stream, struct AdbcError* error); +/// \brief Cancel execution of an in-progress query. +/// +/// This can be called during AdbcStatementExecuteQuery (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no query to cancel. +/// \return ADBC_STATUS_UNKNOWN if the query could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error); + +/// \brief Get a string option of the statement. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the statement. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error); + +/// \brief Get a double option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error); + /// \brief Get the schema for bound parameters. /// /// This retrieves an Arrow schema describing the number, names, and @@ -1159,10 +2201,58 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct AdbcError* error); /// \brief Set a string option on a statement. +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized. ADBC_EXPORT AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error); + +/// \brief Set a double option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error); + /// \addtogroup adbc-statement-partition /// @{ @@ -1198,7 +2288,15 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, /// driver. /// /// Although drivers may choose any name for this function, the -/// recommended name is "AdbcDriverInit". +/// recommended name is "AdbcDriverInit", or a name derived from the +/// name of the driver's shared library as follows: remove the 'lib' +/// prefix (on Unix systems) and all file extensions, then PascalCase +/// the driver name, append Init, and prepend Adbc (if not already +/// there). For example: +/// +/// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit +/// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit +/// - proprietary_driver.dll -> AdbcProprietaryDriverInit /// /// \param[in] version The ADBC revision to attempt to initialize (see /// ADBC_VERSION_1_0_0). diff --git a/go/adbc/drivermgr/adbc_driver_manager.cc b/go/adbc/drivermgr/adbc_driver_manager.cc index d2929e2129..516bf9bbf7 100644 --- a/go/adbc/drivermgr/adbc_driver_manager.cc +++ b/go/adbc/drivermgr/adbc_driver_manager.cc @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -90,17 +92,141 @@ void SetError(struct AdbcError* error, const std::string& message) { // Driver state -/// Hold the driver DLL and the driver release callback in the driver struct. -struct ManagerDriverState { - // The original release callback - AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error); +/// A driver DLL. +struct ManagedLibrary { + ManagedLibrary() : handle(nullptr) {} + ManagedLibrary(ManagedLibrary&& other) : handle(other.handle) { + other.handle = nullptr; + } + ManagedLibrary(const ManagedLibrary&) = delete; + ManagedLibrary& operator=(const ManagedLibrary&) = delete; + ManagedLibrary& operator=(ManagedLibrary&& other) noexcept { + this->handle = other.handle; + other.handle = nullptr; + return *this; + } + + ~ManagedLibrary() { Release(); } + + void Release() { + // TODO(apache/arrow-adbc#204): causes tests to segfault + // Need to refcount the driver DLL; also, errors may retain a reference to + // release() from the DLL - how to handle this? + } + + AdbcStatusCode Load(const char* library, struct AdbcError* error) { + std::string error_message; +#if defined(_WIN32) + HMODULE handle = LoadLibraryExA(library, NULL, 0); + if (!handle) { + error_message += library; + error_message += ": LoadLibraryExA() failed: "; + GetWinError(&error_message); + + std::string full_driver_name = library; + full_driver_name += ".dll"; + handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); + if (!handle) { + error_message += '\n'; + error_message += full_driver_name; + error_message += ": LoadLibraryExA() failed: "; + GetWinError(&error_message); + } + } + if (!handle) { + SetError(error, error_message); + return ADBC_STATUS_INTERNAL; + } else { + this->handle = handle; + } +#else + static const std::string kPlatformLibraryPrefix = "lib"; +#if defined(__APPLE__) + static const std::string kPlatformLibrarySuffix = ".dylib"; +#else + static const std::string kPlatformLibrarySuffix = ".so"; +#endif // defined(__APPLE__) + + void* handle = dlopen(library, RTLD_NOW | RTLD_LOCAL); + if (!handle) { + error_message = "dlopen() failed: "; + error_message += dlerror(); + + // If applicable, append the shared library prefix/extension and + // try again (this way you don't have to hardcode driver names by + // platform in the application) + const std::string driver_str = library; + + std::string full_driver_name; + if (driver_str.size() < kPlatformLibraryPrefix.size() || + driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != + 0) { + full_driver_name += kPlatformLibraryPrefix; + } + full_driver_name += library; + if (driver_str.size() < kPlatformLibrarySuffix.size() || + driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), + kPlatformLibrarySuffix.size(), + kPlatformLibrarySuffix) != 0) { + full_driver_name += kPlatformLibrarySuffix; + } + handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); + if (!handle) { + error_message += "\ndlopen() failed: "; + error_message += dlerror(); + } + } + if (handle) { + this->handle = handle; + } else { + return ADBC_STATUS_INTERNAL; + } +#endif // defined(_WIN32) + return ADBC_STATUS_OK; + } + + AdbcStatusCode Lookup(const char* name, void** func, struct AdbcError* error) { +#if defined(_WIN32) + void* load_handle = reinterpret_cast(GetProcAddress(handle, name)); + if (!load_handle) { + std::string message = "GetProcAddress("; + message += name; + message += ") failed: "; + GetWinError(&message); + SetError(error, message); + return ADBC_STATUS_INTERNAL; + } +#else + void* load_handle = dlsym(handle, name); + if (!load_handle) { + std::string message = "dlsym("; + message += name; + message += ") failed: "; + message += dlerror(); + SetError(error, message); + return ADBC_STATUS_INTERNAL; + } +#endif // defined(_WIN32) + *func = load_handle; + return ADBC_STATUS_OK; + } #if defined(_WIN32) // The loaded DLL HMODULE handle; +#else + void* handle; #endif // defined(_WIN32) }; +/// Hold the driver DLL and the driver release callback in the driver struct. +struct ManagerDriverState { + // The original release callback + AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error); + + ManagedLibrary handle; +}; + /// Unload the driver DLL. static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* error) { AdbcStatusCode status = ADBC_STATUS_OK; @@ -112,35 +238,132 @@ static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* if (state->driver_release) { status = state->driver_release(driver, error); } - -#if defined(_WIN32) - // TODO(apache/arrow-adbc#204): causes tests to segfault - // if (!FreeLibrary(state->handle)) { - // std::string message = "FreeLibrary() failed: "; - // GetWinError(&message); - // SetError(error, message); - // } -#endif // defined(_WIN32) + state->handle.Release(); driver->private_manager = nullptr; delete state; return status; } +// ArrowArrayStream wrapper to support AdbcErrorFromArrayStream + +struct ErrorArrayStream { + struct ArrowArrayStream stream; + struct AdbcDriver* private_driver; +}; + +void ErrorArrayStreamRelease(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return; + + auto* private_data = reinterpret_cast(stream->private_data); + private_data->stream.release(&private_data->stream); + delete private_data; + std::memset(stream, 0, sizeof(*stream)); +} + +const char* ErrorArrayStreamGetLastError(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return nullptr; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_last_error(&private_data->stream); +} + +int ErrorArrayStreamGetNext(struct ArrowArrayStream* stream, struct ArrowArray* array) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_next(&private_data->stream, array); +} + +int ErrorArrayStreamGetSchema(struct ArrowArrayStream* stream, + struct ArrowSchema* schema) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_schema(&private_data->stream, schema); +} + // Default stubs +int ErrorGetDetailCount(const struct AdbcError* error) { return 0; } + +struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError* error, int index) { + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return nullptr; +} + +void ErrorArrayStreamInit(struct ArrowArrayStream* out, + struct AdbcDriver* private_driver) { + if (!out || !out->release || + // Don't bother wrapping if driver didn't claim support + private_driver->ErrorFromArrayStream == ErrorFromArrayStream) { + return; + } + struct ErrorArrayStream* private_data = new ErrorArrayStream; + private_data->stream = *out; + private_data->private_driver = private_driver; + out->get_last_error = ErrorArrayStreamGetLastError; + out->get_next = ErrorArrayStreamGetNext; + out->get_schema = ErrorArrayStreamGetSchema; + out->release = ErrorArrayStreamRelease; + out->private_data = private_data; +} + +AdbcStatusCode DatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode DatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode DatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionCommit(struct AdbcConnection*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, uint32_t* info_codes, - size_t info_codes_length, struct ArrowArrayStream* out, - struct AdbcError* error) { +AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, + const uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* out, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } @@ -150,6 +373,39 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection*, int, const char*, co return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetStatistics(struct AdbcConnection*, const char*, const char*, + const char*, char, struct ArrowArrayStream*, + struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection*, const char*, const char*, const char*, struct ArrowSchema*, struct AdbcError* error) { @@ -178,11 +434,31 @@ AdbcStatusCode ConnectionSetOption(struct AdbcConnection*, const char*, const ch return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*, struct ArrowSchema*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementCancel(struct AdbcStatement* statement, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -191,6 +467,33 @@ AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionBytes(struct AdbcStatement* statement, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionDouble(struct AdbcStatement* statement, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -206,6 +509,21 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement*, const char*, const char return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementSetOptionBytes(struct AdbcStatement*, const char*, const uint8_t*, + size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionDouble(struct AdbcStatement* statement, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement*, const char*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -219,20 +537,129 @@ AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*, /// Temporary state while the database is being configured. struct TempDatabase { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; std::string driver; - // Default name (see adbc.h) - std::string entrypoint = "AdbcDriverInit"; + std::string entrypoint; AdbcDriverInitFunc init_func = nullptr; }; /// Temporary state while the database is being configured. struct TempConnection { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; }; + +static const char kDefaultEntrypoint[] = "AdbcDriverInit"; } // namespace +// Other helpers (intentionally not in an anonymous namespace so they can be tested) + +ADBC_EXPORT +std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) { + /// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit + /// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit + /// - proprietary_driver.dll -> AdbcProprietaryDriverInit + + // Potential path -> filename + // Treat both \ and / as directory separators on all platforms for simplicity + std::string filename; + { + size_t pos = driver.find_last_of("/\\"); + if (pos != std::string::npos) { + filename = driver.substr(pos + 1); + } else { + filename = driver; + } + } + + // Remove all extensions + { + size_t pos = filename.find('.'); + if (pos != std::string::npos) { + filename = filename.substr(0, pos); + } + } + + // Remove lib prefix + // https://stackoverflow.com/q/1878001/262727 + if (filename.rfind("lib", 0) == 0) { + filename = filename.substr(3); + } + + // Split on underscores, hyphens + // Capitalize and join + std::string entrypoint; + entrypoint.reserve(filename.size()); + size_t pos = 0; + while (pos < filename.size()) { + size_t prev = pos; + pos = filename.find_first_of("-_", pos); + // if pos == npos this is the entire filename + std::string token = filename.substr(prev, pos - prev); + // capitalize first letter + token[0] = std::toupper(static_cast(token[0])); + + entrypoint += token; + + if (pos != std::string::npos) { + pos++; + } + } + + if (entrypoint.rfind("Adbc", 0) != 0) { + entrypoint = "Adbc" + entrypoint; + } + entrypoint += "Init"; + + return entrypoint; +} + // Direct implementations of API methods +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetailCount(error); + } + return 0; +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetail(error, index); + } + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + if (!stream->private_data || stream->release != ErrorArrayStreamRelease) { + return nullptr; + } + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->private_driver->ErrorFromArrayStream(&private_data->stream, + status); +} + +#define INIT_ERROR(ERROR, SOURCE) \ + if ((ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ + (ERROR)->private_driver = (SOURCE)->private_driver; \ + } + +#define WRAP_STREAM(EXPR, OUT, SOURCE) \ + if (!(OUT)) { \ + /* Happens for ExecuteQuery where out is optional */ \ + return EXPR; \ + } \ + AdbcStatusCode status_code = EXPR; \ + ErrorArrayStreamInit(OUT, (SOURCE)->private_driver); \ + return status_code; + AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { // Allocate a temporary structure to store options pre-Init database->private_data = new TempDatabase(); @@ -240,9 +667,93 @@ AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOption(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const std::string* result = nullptr; + if (std::strcmp(key, "driver") == 0) { + result = &args->driver; + } else if (std::strcmp(key, "entrypoint") == 0) { + result = &args->entrypoint; + } else { + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + result = &it->second; + } + + if (*length <= result->size() + 1) { + // Enough space + std::memcpy(value, result->c_str(), result->size() + 1); + } + *length = result->size() + 1; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionBytes(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + const std::string& result = it->second; + + if (*length <= result.size()) { + // Enough space + std::memcpy(value, result.c_str(), result.size()); + } + *length = result.size(); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionInt(database, key, value, error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionDouble(database, key, value, error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (database->private_driver) { + INIT_ERROR(error, database); return database->private_driver->DatabaseSetOption(database, key, value, error); } @@ -257,6 +768,44 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionBytes(database, key, value, length, + error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionInt(database, key, value, error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionDouble(database, key, value, error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database, AdbcDriverInitFunc init_func, struct AdbcError* error) { @@ -288,11 +837,14 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* // So we don't confuse a driver into thinking it's initialized already database->private_data = nullptr; if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, + status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_1_0, database->private_driver, error); - } else { + } else if (!args->entrypoint.empty()) { status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), - ADBC_VERSION_1_0_0, database->private_driver, error); + ADBC_VERSION_1_1_0, database->private_driver, error); + } else { + status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, + database->private_driver, error); } if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease @@ -313,25 +865,49 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_driver = nullptr; return status; } - for (const auto& option : args->options) { + auto options = std::move(args->options); + auto bytes_options = std::move(args->bytes_options); + auto int_options = std::move(args->int_options); + auto double_options = std::move(args->double_options); + delete args; + + INIT_ERROR(error, database); + for (const auto& option : options) { status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); - if (status != ADBC_STATUS_OK) { - delete args; - // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - // Should be redundant, but ensure that AdbcDatabaseRelease - // below doesn't think that it contains a TempDatabase - database->private_data = nullptr; - return status; + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : bytes_options) { + status = database->private_driver->DatabaseSetOptionBytes( + database, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : int_options) { + status = database->private_driver->DatabaseSetOptionInt( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : double_options) { + status = database->private_driver->DatabaseSetOptionDouble( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + + if (status != ADBC_STATUS_OK) { + // Release the database + std::ignore = database->private_driver->DatabaseRelease(database, error); + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); } + delete database->private_driver; + database->private_driver = nullptr; + // Should be redundant, but ensure that AdbcDatabaseRelease + // below doesn't think that it contains a TempDatabase + database->private_data = nullptr; + return status; } - delete args; return database->private_driver->DatabaseInit(database, error); } @@ -346,6 +922,7 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, database); auto status = database->private_driver->DatabaseRelease(database, error); if (database->private_driver->release) { database->private_driver->release(database->private_driver, error); @@ -356,23 +933,35 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, return status; } +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionCancel(connection, error); +} + AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetInfo(connection, info_codes, - info_codes_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetInfo( + connection, info_codes, info_codes_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -384,9 +973,132 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetObjects( - connection, depth, catalog, db_schema, table_name, table_types, column_name, stream, - error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetObjects( + connection, depth, catalog, db_schema, table_name, table_types, + column_name, stream, error), + stream, connection); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.c_str(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOption(connection, key, value, length, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.data(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionBytes(connection, key, value, + length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionDouble(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatistics( + connection, catalog, db_schema, table_name, approximate == 1, out, error), + out, connection); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatisticNames(connection, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -397,6 +1109,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionGetTableSchema( connection, catalog, db_schema, table_name, schema, error); } @@ -407,7 +1120,10 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetTableTypes(connection, stream, error), + stream, connection); } AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, @@ -423,6 +1139,11 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, TempConnection* args = reinterpret_cast(connection->private_data); connection->private_data = nullptr; std::unordered_map options = std::move(args->options); + std::unordered_map bytes_options = + std::move(args->bytes_options); + std::unordered_map int_options = std::move(args->int_options); + std::unordered_map double_options = + std::move(args->double_options); delete args; auto status = database->private_driver->ConnectionNew(connection, error); @@ -434,6 +1155,24 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, connection, option.first.c_str(), option.second.c_str(), error); if (status != ADBC_STATUS_OK) return status; } + for (const auto& option : bytes_options) { + status = database->private_driver->ConnectionSetOptionBytes( + connection, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : int_options) { + status = database->private_driver->ConnectionSetOptionInt( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : double_options) { + status = database->private_driver->ConnectionSetOptionDouble( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionInit(connection, database, error); } @@ -455,8 +1194,10 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionReadPartition( - connection, serialized_partition, serialized_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionReadPartition( + connection, serialized_partition, serialized_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, @@ -470,6 +1211,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->ConnectionRelease(connection, error); connection->private_driver = nullptr; return status; @@ -480,6 +1222,7 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionRollback(connection, error); } @@ -495,15 +1238,71 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const args->options[key] = value; return ADBC_STATUS_OK; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionBytes(connection, key, value, + length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionDouble: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionDouble(connection, key, value, + error); +} + AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBind(statement, values, schema, error); } @@ -513,9 +1312,19 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementCancel(statement, error); +} + // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, ArrowSchema* schema, @@ -525,6 +1334,7 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementExecutePartitions( statement, schema, partitions, rows_affected, error); } @@ -536,8 +1346,62 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, - error); + INIT_ERROR(error, statement); + WRAP_STREAM(statement->private_driver->StatementExecuteQuery(statement, out, + rows_affected, error), + out, statement); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOption(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionDouble(statement, key, value, + error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -546,6 +1410,7 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementGetParameterSchema(statement, schema, error); } @@ -555,6 +1420,7 @@ AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->StatementNew(connection, statement, error); statement->private_driver = connection->private_driver; return status; @@ -565,6 +1431,7 @@ AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementPrepare(statement, error); } @@ -573,6 +1440,7 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); auto status = statement->private_driver->StatementRelease(statement, error); statement->private_driver = nullptr; return status; @@ -583,14 +1451,47 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionDouble(statement, key, value, + error); +} + AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSqlQuery(statement, query, error); } @@ -600,6 +1501,7 @@ AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); } @@ -636,137 +1538,80 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, AdbcDriverInitFunc init_func; std::string error_message; - if (version != ADBC_VERSION_1_0_0) { - SetError(error, "Only ADBC 1.0.0 is supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - auto* driver = reinterpret_cast(raw_driver); - - if (!entrypoint) { - // Default entrypoint (see adbc.h) - entrypoint = "AdbcDriverInit"; + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; } -#if defined(_WIN32) - - HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); - if (!handle) { - error_message += driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - - std::string full_driver_name = driver_name; - full_driver_name += ".lib"; - handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); - if (!handle) { - error_message += '\n'; - error_message += full_driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - } - } - if (!handle) { - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; } + auto* driver = reinterpret_cast(raw_driver); - void* load_handle = reinterpret_cast(GetProcAddress(handle, entrypoint)); - init_func = reinterpret_cast(load_handle); - if (!init_func) { - std::string message = "GetProcAddress("; - message += entrypoint; - message += ") failed: "; - GetWinError(&message); - if (!FreeLibrary(handle)) { - message += "\nFreeLibrary() failed: "; - GetWinError(&message); - } - SetError(error, message); - return ADBC_STATUS_INTERNAL; + ManagedLibrary library; + AdbcStatusCode status = library.Load(driver_name, error); + if (status != ADBC_STATUS_OK) { + // AdbcDatabaseInit tries to call this if set + driver->release = nullptr; + return status; } -#else - -#if defined(__APPLE__) - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".dylib"; -#else - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".so"; -#endif // defined(__APPLE__) - - void* handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message = "dlopen() failed: "; - error_message += dlerror(); - - // If applicable, append the shared library prefix/extension and - // try again (this way you don't have to hardcode driver names by - // platform in the application) - const std::string driver_str = driver_name; - - std::string full_driver_name; - if (driver_str.size() < kPlatformLibraryPrefix.size() || - driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != - 0) { - full_driver_name += kPlatformLibraryPrefix; - } - full_driver_name += driver_name; - if (driver_str.size() < kPlatformLibrarySuffix.size() || - driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix) != 0) { - full_driver_name += kPlatformLibrarySuffix; - } - handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message += "\ndlopen() failed: "; - error_message += dlerror(); + void* load_handle = nullptr; + if (entrypoint) { + status = library.Lookup(entrypoint, &load_handle, error); + } else { + auto name = AdbcDriverManagerDefaultEntrypoint(driver_name); + status = library.Lookup(name.c_str(), &load_handle, error); + if (status != ADBC_STATUS_OK) { + status = library.Lookup(kDefaultEntrypoint, &load_handle, error); } } - if (!handle) { - SetError(error, error_message); - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; - return ADBC_STATUS_INTERNAL; - } - void* load_handle = dlsym(handle, entrypoint); - if (!load_handle) { - std::string message = "dlsym("; - message += entrypoint; - message += ") failed: "; - message += dlerror(); - SetError(error, message); - return ADBC_STATUS_INTERNAL; + if (status != ADBC_STATUS_OK) { + library.Release(); + return status; } init_func = reinterpret_cast(load_handle); -#endif // defined(_WIN32) - - AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); + status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); if (status == ADBC_STATUS_OK) { ManagerDriverState* state = new ManagerDriverState; state->driver_release = driver->release; -#if defined(_WIN32) - state->handle = handle; -#endif // defined(_WIN32) + state->handle = std::move(library); driver->release = &ReleaseDriver; driver->private_manager = state; } else { -#if defined(_WIN32) - if (!FreeLibrary(handle)) { - std::string message = "FreeLibrary() failed: "; - GetWinError(&message); - SetError(error, message); - } -#endif // defined(_WIN32) + library.Release(); } return status; } AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void* raw_driver, struct AdbcError* error) { + constexpr std::array kSupportedVersions = { + ADBC_VERSION_1_1_0, + ADBC_VERSION_1_0_0, + }; + + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + #define FILL_DEFAULT(DRIVER, STUB) \ if (!DRIVER->STUB) { \ DRIVER->STUB = &STUB; \ @@ -777,12 +1622,20 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers return ADBC_STATUS_INTERNAL; \ } - auto result = init_func(version, raw_driver, error); + // Starting from the passed version, try each (older) version in + // succession with the underlying driver until we find one that's + // accepted. + AdbcStatusCode result = ADBC_STATUS_NOT_IMPLEMENTED; + for (const int try_version : kSupportedVersions) { + if (try_version > version) continue; + result = init_func(try_version, raw_driver, error); + if (result != ADBC_STATUS_NOT_IMPLEMENTED) break; + } if (result != ADBC_STATUS_OK) { return result; } - if (version == ADBC_VERSION_1_0_0) { + if (version >= ADBC_VERSION_1_0_0) { auto* driver = reinterpret_cast(raw_driver); CHECK_REQUIRED(driver, DatabaseNew); CHECK_REQUIRED(driver, DatabaseInit); @@ -812,6 +1665,41 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers FILL_DEFAULT(driver, StatementSetSqlQuery); FILL_DEFAULT(driver, StatementSetSubstraitPlan); } + if (version >= ADBC_VERSION_1_1_0) { + auto* driver = reinterpret_cast(raw_driver); + FILL_DEFAULT(driver, ErrorGetDetailCount); + FILL_DEFAULT(driver, ErrorGetDetail); + FILL_DEFAULT(driver, ErrorFromArrayStream); + + FILL_DEFAULT(driver, DatabaseGetOption); + FILL_DEFAULT(driver, DatabaseGetOptionBytes); + FILL_DEFAULT(driver, DatabaseGetOptionDouble); + FILL_DEFAULT(driver, DatabaseGetOptionInt); + FILL_DEFAULT(driver, DatabaseSetOptionBytes); + FILL_DEFAULT(driver, DatabaseSetOptionDouble); + FILL_DEFAULT(driver, DatabaseSetOptionInt); + + FILL_DEFAULT(driver, ConnectionCancel); + FILL_DEFAULT(driver, ConnectionGetOption); + FILL_DEFAULT(driver, ConnectionGetOptionBytes); + FILL_DEFAULT(driver, ConnectionGetOptionDouble); + FILL_DEFAULT(driver, ConnectionGetOptionInt); + FILL_DEFAULT(driver, ConnectionGetStatistics); + FILL_DEFAULT(driver, ConnectionGetStatisticNames); + FILL_DEFAULT(driver, ConnectionSetOptionBytes); + FILL_DEFAULT(driver, ConnectionSetOptionDouble); + FILL_DEFAULT(driver, ConnectionSetOptionInt); + + FILL_DEFAULT(driver, StatementCancel); + FILL_DEFAULT(driver, StatementExecuteSchema); + FILL_DEFAULT(driver, StatementGetOption); + FILL_DEFAULT(driver, StatementGetOptionBytes); + FILL_DEFAULT(driver, StatementGetOptionDouble); + FILL_DEFAULT(driver, StatementGetOptionInt); + FILL_DEFAULT(driver, StatementSetOptionBytes); + FILL_DEFAULT(driver, StatementSetOptionDouble); + FILL_DEFAULT(driver, StatementSetOptionInt); + } return ADBC_STATUS_OK; diff --git a/go/adbc/infocode_string.go b/go/adbc/infocode_string.go index 73af20c1e1..df0fd74b96 100644 --- a/go/adbc/infocode_string.go +++ b/go/adbc/infocode_string.go @@ -14,23 +14,24 @@ func _() { _ = x[InfoDriverName-100] _ = x[InfoDriverVersion-101] _ = x[InfoDriverArrowVersion-102] + _ = x[InfoDriverADBCVersion-103] } const ( _InfoCode_name_0 = "VendorNameVendorVersionVendorArrowVersion" - _InfoCode_name_1 = "DriverNameDriverVersionDriverArrowVersion" + _InfoCode_name_1 = "DriverNameDriverVersionDriverArrowVersionDriverADBCVersion" ) var ( _InfoCode_index_0 = [...]uint8{0, 10, 23, 41} - _InfoCode_index_1 = [...]uint8{0, 10, 23, 41} + _InfoCode_index_1 = [...]uint8{0, 10, 23, 41, 58} ) func (i InfoCode) String() string { switch { case i <= 2: return _InfoCode_name_0[_InfoCode_index_0[i]:_InfoCode_index_0[i+1]] - case 100 <= i && i <= 102: + case 100 <= i && i <= 103: i -= 100 return _InfoCode_name_1[_InfoCode_index_1[i]:_InfoCode_index_1[i+1]] default: diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl index 7432f001b9..fc489a4016 100644 --- a/go/adbc/pkg/_tmpl/driver.go.tmpl +++ b/go/adbc/pkg/_tmpl/driver.go.tmpl @@ -26,11 +26,22 @@ package main // #cgo CXXFLAGS: -std=c++11 -DADBC_EXPORTING // #include "../../drivermgr/adbc.h" // #include "utils.h" +// #include // #include // #include // // typedef const char cchar_t; // typedef const uint8_t cuint8_t; +// typedef const uint32_t cuint32_t; +// typedef const struct AdbcError ConstAdbcError; +// +// int {{.Prefix}}ArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// int {{.Prefix}}ArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// const char* {{.Prefix}}ArrayStreamGetLastError(struct ArrowArrayStream*); +// void {{.Prefix}}ArrayStreamRelease(struct ArrowArrayStream*); +// +// int {{.Prefix}}ArrayStreamGetSchemaTrampoline(struct ArrowArrayStream*, struct ArrowSchema*); +// int {{.Prefix}}ArrayStreamGetNextTrampoline(struct ArrowArrayStream*, struct ArrowArray*); // // void releasePartitions(struct AdbcPartitions* partitions); // @@ -48,6 +59,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/cdata" + "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/arrow/memory/mallocator" ) @@ -74,14 +86,63 @@ func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) { err.release = (*[0]byte)(C.{{.Prefix}}_release_error) } +func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { + if err == nil { + return + } + + if err.vendor_code != C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA { + setErr(err, adbcError.Msg) + return + } + + cErrPtr := C.malloc(C.sizeof_struct_{{.Prefix}}Error) + cErr := (*C.struct_{{.Prefix}}Error)(cErrPtr) + cErr.message = C.CString(adbcError.Msg) + err.message = cErr.message + err.release = (*[0]byte)(C.{{.Prefix}}ReleaseErrWithDetails) + err.private_data = cErrPtr + + numDetails := len(adbcError.Details) + if numDetails > 0 { + cErr.keys = (**C.cchar_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cchar_t)(nil))))) + cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) + cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + + keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) + values := fromCArr[*C.cuint8_t](cErr.values, numDetails) + lengths := fromCArr[C.size_t](cErr.lengths, numDetails) + + for i, detail := range adbcError.Details { + keys[i] = C.CString(detail.Key()) + bytes, err := detail.Serialize() + if err != nil { + msg := err.Error() + values[i] = (*C.cuint8_t)(unsafe.Pointer(C.CString(msg))) + lengths[i] = C.size_t(len(msg)) + } else { + values[i] = (*C.cuint8_t)(C.malloc(C.size_t(len(bytes)))) + sink := fromCArr[byte]((*byte)(values[i]), len(bytes)) + copy(sink, bytes) + lengths[i] = C.size_t(len(bytes)) + } + } + } else { + cErr.keys = nil + cErr.values = nil + cErr.lengths = nil + } + cErr.count = C.int(numDetails) +} + func errToAdbcErr(adbcerr *C.struct_AdbcError, err error) adbc.Status { - if adbcerr == nil || err == nil { + if err == nil { return adbc.StatusOK } var adbcError adbc.Error if errors.As(err, &adbcError) { - setErr(adbcerr, adbcError.Msg) + setErrWithDetails(adbcerr, adbcError) return adbcError.Code } @@ -119,6 +180,45 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T { return cgo.Handle((uintptr)(*hptr)).Value().(*T) } +func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { + lenWithTerminator := C.size_t(len(val) + 1) + if lenWithTerminator <= *length { + sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) + copy(sink, val) + sink[lenWithTerminator] = 0 + } + *length = lenWithTerminator + return C.ADBC_STATUS_OK +} + +func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode { + if C.size_t(len(val)) <= *length { + sink := fromCArr[byte]((*byte)(out), int(*length)) + copy(sink, val) + } + *length = C.size_t(len(val)) + return C.ADBC_STATUS_OK +} + +type cancellableContext struct { + ctx context.Context + cancel context.CancelFunc +} + +func (c *cancellableContext) newContext() context.Context { + c.cancelContext() + c.ctx, c.cancel = context.WithCancel(context.Background()) + return c.ctx +} + +func (c *cancellableContext) cancelContext() { + if c.cancel != nil { + c.cancel() + } + c.ctx = nil + c.cancel = nil +} + func checkDBAlloc(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) @@ -148,48 +248,243 @@ func checkDBInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname strin return cdb } +// Custom ArrowArrayStream export to support ADBC error data in ArrowArrayStream + +type cArrayStream struct { + rdr array.RecordReader + // Must be C-allocated + adbcErr *C.struct_AdbcError + status C.AdbcStatusCode +} + +func (cStream *cArrayStream) maybeError() C.int { + err := cStream.rdr.Err() + if err != nil { + if cStream.adbcErr != nil { + C.{{.Prefix}}errRelease(cStream.adbcErr) + } else { + cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + } + cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) + switch adbc.Status(cStream.status) { + case adbc.StatusUnknown: + return C.EIO + case adbc.StatusNotImplemented: + return C.ENOTSUP + case adbc.StatusNotFound: + return C.ENOENT + case adbc.StatusAlreadyExists: + return C.EEXIST + case adbc.StatusInvalidArgument: + return C.EINVAL + case adbc.StatusInvalidState: + return C.EINVAL + case adbc.StatusInvalidData: + return C.EIO + case adbc.StatusIntegrity: + return C.EIO + case adbc.StatusInternal: + return C.EIO + case adbc.StatusIO: + return C.EIO + case adbc.StatusCancelled: + return C.ECANCELED + case adbc.StatusTimeout: + return C.ETIMEDOUT + case adbc.StatusUnauthenticated: + return C.EACCES + case adbc.StatusUnauthorized: + return C.EACCES + default: + return C.EIO + } + } + return 0 +} + +//export {{.Prefix}}ArrayStreamGetLastError +func {{.Prefix}}ArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.adbcErr != nil { + return cStream.adbcErr.message + } + return nil +} + +//export {{.Prefix}}ArrayStreamGetNext +func {{.Prefix}}ArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.rdr.Next() { + cdata.ExportArrowRecordBatch(cStream.rdr.Record(), toCdataArray(array), nil) + return 0; + } + array.release = nil + array.private_data = nil + return cStream.maybeError() +} + +//export {{.Prefix}}ArrayStreamGetSchema +func {{.Prefix}}ArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + s := cStream.rdr.Schema() + if s == nil { + return cStream.maybeError() + } + cdata.ExportArrowSchema(s, toCdataSchema(schema)) + return 0 +} + +//export {{.Prefix}}ArrayStreamRelease +func {{.Prefix}}ArrayStreamRelease(stream *C.struct_ArrowArrayStream) { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return + } + h := (*(*cgo.Handle)(stream.private_data)) + + cStream := h.Value().(*cArrayStream) + cStream.rdr.Release() + if cStream.adbcErr != nil { + C.{{.Prefix}}errRelease(cStream.adbcErr) + C.free(unsafe.Pointer(cStream.adbcErr)) + } + C.free(unsafe.Pointer(stream.private_data)) + stream.private_data = nil + h.Delete() + runtime.GC() +} + +//export {{.Prefix}}ErrorFromArrayStream +func {{.Prefix}}ErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) (*C.struct_AdbcError) { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if status != nil { + *status = cStream.status + } + return cStream.adbcErr +} + +func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStream) { + cStream := &cArrayStream{rdr: rdr, status: C.ADBC_STATUS_OK} + stream.get_last_error = (*[0]byte)(C.{{.Prefix}}ArrayStreamGetLastError) + stream.get_next = (*[0]byte)(C.{{.Prefix}}ArrayStreamGetNextTrampoline) + stream.get_schema = (*[0]byte)(C.{{.Prefix}}ArrayStreamGetSchemaTrampoline) + stream.release = (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) + hndl := cgo.NewHandle(cStream) + stream.private_data = createHandle(hndl) + rdr.Retain() +} + type cDatabase struct { opts map[string]string db adbc.Database } -//export {{.Prefix}}DatabaseNew -func {{.Prefix}}DatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export {{.Prefix}}DatabaseGetOption +func {{.Prefix}}DatabaseGetOption(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseNew", e) + code = poison(err, "AdbcDatabaseGetOption", e) } }() - if atomic.LoadInt32(&globalPoison) != 0 { - setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") - return C.ADBC_STATUS_INTERNAL + cdb := checkDBInit(db, err, "AdbcDatabaseGetOption") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE } - if db.private_data != nil { - setErr(err, "AdbcDatabaseNew: database already allocated") + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export {{.Prefix}}DatabaseGetOptionBytes +func {{.Prefix}}DatabaseGetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionBytes") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} - hndl := cgo.NewHandle(dbobj) - db.private_data = createHandle(hndl) - return C.ADBC_STATUS_OK + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) } -//export {{.Prefix}}DatabaseSetOption -func {{.Prefix}}DatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export {{.Prefix}}DatabaseGetOptionDouble +func {{.Prefix}}DatabaseGetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseSetOption", e) + code = poison(err, "AdbcDatabaseGetOptionDouble", e) } }() - if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionDouble") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - cdb := getFromHandle[cDatabase](db.private_data) - k, v := C.GoString(key), C.GoString(value) - cdb.opts[k] = v + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return C.ADBC_STATUS_OK + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export {{.Prefix}}DatabaseGetOptionInt +func {{.Prefix}}DatabaseGetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export {{.Prefix}}DatabaseInit @@ -218,6 +513,27 @@ func {{.Prefix}}DatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) return C.ADBC_STATUS_OK } +//export {{.Prefix}}DatabaseNew +func {{.Prefix}}DatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseNew", e) + } + }() + if atomic.LoadInt32(&globalPoison) != 0 { + setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") + return C.ADBC_STATUS_INTERNAL + } + if db.private_data != nil { + setErr(err, "AdbcDatabaseNew: database already allocated") + return C.ADBC_STATUS_INVALID_STATE + } + dbobj := &cDatabase{opts: make(map[string]string)} + hndl := cgo.NewHandle(dbobj) + db.private_data = createHandle(hndl) + return C.ADBC_STATUS_OK +} + //export {{.Prefix}}DatabaseRelease func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -246,7 +562,99 @@ func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErr return C.ADBC_STATUS_OK } +//export {{.Prefix}}DatabaseSetOption +func {{.Prefix}}DatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOption", e) + } + }() + if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + return C.ADBC_STATUS_INVALID_STATE + } + cdb := getFromHandle[cDatabase](db.private_data) + + k, v := C.GoString(key), C.GoString(value) + if cdb.db != nil { + opts, ok := cdb.db.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(k, v))) + } else { + cdb.opts[k] = v + } + + return C.ADBC_STATUS_OK +} + +//export {{.Prefix}}DatabaseSetOptionBytes +func {{.Prefix}}DatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export {{.Prefix}}DatabaseSetOptionDouble +func {{.Prefix}}DatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionDouble", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionDouble") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export {{.Prefix}}DatabaseSetOptionInt +func {{.Prefix}}DatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) +} + type cConn struct { + cancellableContext + cnxn adbc.Connection initArgs map[string]string } @@ -280,6 +688,102 @@ func checkConnInit(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname return conn } +//export {{.Prefix}}ConnectionGetOption +func {{.Prefix}}ConnectionGetOption(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOption", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOption") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export {{.Prefix}}ConnectionGetOptionBytes +func {{.Prefix}}ConnectionGetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export {{.Prefix}}ConnectionGetOptionDouble +func {{.Prefix}}ConnectionGetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export {{.Prefix}}ConnectionGetOptionInt +func {{.Prefix}}ConnectionGetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + //export {{.Prefix}}ConnectionNew func {{.Prefix}}ConnectionNew(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -323,13 +827,75 @@ func {{.Prefix}}ConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.c return C.ADBC_STATUS_OK } - opts, ok := conn.cnxn.(adbc.PostInitOptions) + opts, ok := conn.cnxn.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcConnectionSetOption: not supported post-init") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))) +} + +//export {{.Prefix}}ConnectionSetOptionBytes +func {{.Prefix}}ConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export {{.Prefix}}ConnectionSetOptionDouble +func {{.Prefix}}ConnectionSetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export {{.Prefix}}ConnectionSetOptionInt +func {{.Prefix}}ConnectionSetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) if !ok { - setErr(err, "AdbcConnectionSetOption: not supported post-init") + setErr(err, "AdbcConnectionSetOptionInt: options are not supported") return C.ADBC_STATUS_NOT_IMPLEMENTED } - rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val))) - return C.AdbcStatusCode(rawCode) + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export {{.Prefix}}ConnectionInit @@ -392,8 +958,9 @@ func {{.Prefix}}ConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_A conn := h.Value().(*cConn) defer func() { + conn.cancelContext() conn.cnxn = nil - C.free(unsafe.Pointer(cnxn.private_data)) + C.free(cnxn.private_data) cnxn.private_data = nil h.Delete() // manually trigger GC for two reasons: @@ -430,26 +997,19 @@ func toCdataArray(ptr *C.struct_ArrowArray) *cdata.CArrowArray { return (*cdata.CArrowArray)(unsafe.Pointer(ptr)) } -//export {{.Prefix}}ConnectionGetInfo -func {{.Prefix}}ConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.uint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export {{.Prefix}}ConnectionCancel +func {{.Prefix}}ConnectionCancel(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcConnectionGetInfo", e) + code = poison(err, "AdbcConnectionCancel", e) } }() - conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + conn := checkConnInit(cnxn, err, "AdbcConnectionCancel") if conn == nil { return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) - rdr, e := conn.cnxn.GetInfo(context.Background(), infoCodes) - if e != nil { - return C.AdbcStatusCode(errToAdbcErr(err, e)) - } - - defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + conn.cancelContext() return C.ADBC_STATUS_OK } @@ -477,6 +1037,29 @@ func toStrSlice(in **C.cchar_t) []string { return out } +//export {{.Prefix}}ConnectionGetInfo +func {{.Prefix}}ConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetInfo", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + //export {{.Prefix}}ConnectionGetObjects func {{.Prefix}}ConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, catalog, dbSchema, tableName *C.cchar_t, tableType **C.cchar_t, columnName *C.cchar_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { @@ -490,12 +1073,67 @@ func {{.Prefix}}ConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetObjects(context.Background(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + rdr, e := conn.cnxn.GetObjects(conn.newContext(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export {{.Prefix}}ConnectionGetStatistics +func {{.Prefix}}ConnectionGetStatistics(cnxn *C.struct_AdbcConnection, catalog, dbSchema, tableName *C.cchar_t, approximate C.char, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatistics(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), int(approximate) != 0) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export {{.Prefix}}ConnectionGetStatisticNames +func {{.Prefix}}ConnectionGetStatisticNames(cnxn *C.struct_AdbcConnection, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatisticNames(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -511,7 +1149,7 @@ func {{.Prefix}}ConnectionGetTableSchema(cnxn *C.struct_AdbcConnection, catalog, return C.ADBC_STATUS_INVALID_STATE } - sc, e := conn.cnxn.GetTableSchema(context.Background(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) + sc, e := conn.cnxn.GetTableSchema(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -531,12 +1169,12 @@ func {{.Prefix}}ConnectionGetTableTypes(cnxn *C.struct_AdbcConnection, out *C.st return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetTableTypes(context.Background()) + rdr, e := conn.cnxn.GetTableTypes(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -552,12 +1190,12 @@ func {{.Prefix}}ConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialize return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(context.Background(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -573,7 +1211,7 @@ func {{.Prefix}}ConnectionCommit(cnxn *C.struct_AdbcConnection, err *C.struct_Ad return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(conn.newContext()))) } //export {{.Prefix}}ConnectionRollback @@ -588,25 +1226,137 @@ func {{.Prefix}}ConnectionRollback(cnxn *C.struct_AdbcConnection, err *C.struct_ return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(conn.newContext()))) +} + +type cStmt struct { + cancellableContext + + stmt adbc.Statement } -func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) adbc.Statement { +func checkStmtAlloc(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) - return nil + return false } if stmt == nil { setErr(err, "%s: statement not allocated", fname) - return nil + return false } - if stmt.private_data == nil { - setErr(err, "%s: statement not initialized", fname) + setErr(err, "%s: statement not allocated", fname) + return false + } + return true +} + +func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) *cStmt { + if !checkStmtAlloc(stmt, err, fname) { + return nil + } + cStmt := getFromHandle[cStmt](stmt.private_data) + if cStmt.stmt == nil { + setErr(err, "%s: statement not allocated", fname) return nil } + return cStmt +} + +//export {{.Prefix}}StatementGetOption +func {{.Prefix}}StatementGetOption(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOption", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOption") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export {{.Prefix}}StatementGetOptionBytes +func {{.Prefix}}StatementGetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export {{.Prefix}}StatementGetOptionDouble +func {{.Prefix}}StatementGetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export {{.Prefix}}StatementGetOptionInt +func {{.Prefix}}StatementGetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return (*(*cgo.Handle)(stmt.private_data)).Value().(adbc.Statement) + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export {{.Prefix}}StatementNew @@ -635,8 +1385,8 @@ func {{.Prefix}}StatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcS return C.AdbcStatusCode(errToAdbcErr(err, e)) } - h := cgo.NewHandle(st) - stmt.private_data = createHandle(h) + hndl := cgo.NewHandle(&cStmt{stmt: st}) + stmt.private_data = createHandle(hndl) return C.ADBC_STATUS_OK } @@ -651,31 +1401,46 @@ func {{.Prefix}}StatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_Adb setErr(err, "AdbcStatementRelease: Go panicked, driver is in unknown state") return C.ADBC_STATUS_INTERNAL } - if stmt == nil { - setErr(err, "AdbcStatementRelease: statement not allocated") + if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } + h := (*(*cgo.Handle)(stmt.private_data)) - if stmt.private_data == nil { - setErr(err, "AdbcStatementRelease: statement not initialized") - return C.ADBC_STATUS_INVALID_STATE + st := h.Value().(*cStmt) + defer func() { + st.cancelContext() + st.stmt = nil + C.free(stmt.private_data) + stmt.private_data = nil + h.Delete() + // manually trigger GC for two reasons: + // 1. ASAN expects the release callback to be called before + // the process ends, but GC is not deterministic. So by manually + // triggering the GC we ensure the release callback gets called. + // 2. Creates deterministic GC behavior by all Release functions + // triggering a garbage collection + runtime.GC() + }() + if st.stmt == nil { + return C.ADBC_STATUS_OK } + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Close())) +} - h := (*(*cgo.Handle)(stmt.private_data)) - st := h.Value().(adbc.Statement) - C.free(stmt.private_data) - stmt.private_data = nil +//export {{.Prefix}}StatementCancel +func {{.Prefix}}StatementCancel(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementCancel", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementCancel") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } - e := st.Close() - h.Delete() - // manually trigger GC for two reasons: - // 1. ASAN expects the release callback to be called before - // the process ends, but GC is not deterministic. So by manually - // triggering the GC we ensure the release callback gets called. - // 2. Creates deterministic GC behavior by all Release functions - // triggering a garbage collection - runtime.GC() - return C.AdbcStatusCode(errToAdbcErr(err, e)) + st.cancelContext() + return C.ADBC_STATUS_OK } //export {{.Prefix}}StatementPrepare @@ -690,7 +1455,7 @@ func {{.Prefix}}StatementPrepare(stmt *C.struct_AdbcStatement, err *C.struct_Adb return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.Prepare(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Prepare(st.newContext()))) } //export {{.Prefix}}StatementExecuteQuery @@ -706,7 +1471,7 @@ func {{.Prefix}}StatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struc } if out == nil { - n, e := st.ExecuteUpdate(context.Background()) + n, e := st.stmt.ExecuteUpdate(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -715,7 +1480,7 @@ func {{.Prefix}}StatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struc *affected = C.int64_t(n) } } else { - rdr, n, e := st.ExecuteQuery(context.Background()) + rdr, n, e := st.stmt.ExecuteQuery(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -725,8 +1490,35 @@ func {{.Prefix}}StatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struc } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) + } + return C.ADBC_STATUS_OK +} + +//export {{.Prefix}}StatementExecuteSchema +func {{.Prefix}}StatementExecuteSchema(stmt *C.struct_AdbcStatement, schema *C.struct_ArrowSchema, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementExecuteQuery", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementExecuteQuery") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + es, ok := st.stmt.(adbc.StatementExecuteSchema) + if !ok { + setErr(err, "AdbcStatementExecuteSchema: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + sc, e := es.ExecuteSchema(st.newContext()) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) } + + cdata.ExportArrowSchema(sc, toCdataSchema(schema)) return C.ADBC_STATUS_OK } @@ -742,7 +1534,7 @@ func {{.Prefix}}StatementSetSqlQuery(stmt *C.struct_AdbcStatement, query *C.ccha return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSqlQuery(C.GoString(query)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSqlQuery(C.GoString(query)))) } //export {{.Prefix}}StatementSetSubstraitPlan @@ -757,7 +1549,7 @@ func {{.Prefix}}StatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C. return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) } //export {{.Prefix}}StatementBind @@ -780,7 +1572,7 @@ func {{.Prefix}}StatementBind(stmt *C.struct_AdbcStatement, values *C.struct_Arr } defer rec.Release() - return C.AdbcStatusCode(errToAdbcErr(err, st.Bind(context.Background(), rec))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Bind(st.newContext(), rec))) } //export {{.Prefix}}StatementBindStream @@ -799,7 +1591,7 @@ func {{.Prefix}}StatementBindStream(stmt *C.struct_AdbcStatement, stream *C.stru if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } - return C.AdbcStatusCode(errToAdbcErr(err, st.BindStream(context.Background(), rdr.(array.RecordReader)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.BindStream(st.newContext(), rdr.(array.RecordReader)))) } //export {{.Prefix}}StatementGetParameterSchema @@ -814,7 +1606,7 @@ func {{.Prefix}}StatementGetParameterSchema(stmt *C.struct_AdbcStatement, schema return C.ADBC_STATUS_INVALID_STATE } - sc, e := st.GetParameterSchema() + sc, e := st.stmt.GetParameterSchema() if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -835,7 +1627,70 @@ func {{.Prefix}}StatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.c return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetOption(C.GoString(key), C.GoString(value)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetOption(C.GoString(key), C.GoString(value)))) +} + +//export {{.Prefix}}StatementSetOptionBytes +func {{.Prefix}}StatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export {{.Prefix}}StatementSetOptionDouble +func {{.Prefix}}StatementSetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export {{.Prefix}}StatementSetOptionInt +func {{.Prefix}}StatementSetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export releasePartitions @@ -864,7 +1719,7 @@ func {{.Prefix}}StatementExecutePartitions(stmt *C.struct_AdbcStatement, schema return C.ADBC_STATUS_INVALID_STATE } - sc, part, n, e := st.ExecutePartitions(context.Background()) + sc, part, n, e := st.stmt.ExecutePartitions(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -909,13 +1764,20 @@ func {{.Prefix}}StatementExecutePartitions(stmt *C.struct_AdbcStatement, schema //export {{.Prefix}}DriverInit func {{.Prefix}}DriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcError) C.AdbcStatusCode { - if version != C.ADBC_VERSION_1_0_0 { - setErr(err, "Only version %d supported, got %d", int(C.ADBC_VERSION_1_0_0), int(version)) + driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) + + switch version { + case C.ADBC_VERSION_1_0_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) + memory.Set(sink, 0) + case C.ADBC_VERSION_1_1_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) + memory.Set(sink, 0) + default: + setErr(err, "Only version 1.0.0/1.1.0 supported, got %d", int(version)) return C.ADBC_STATUS_NOT_IMPLEMENTED } - driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) - C.memset(unsafe.Pointer(driver), 0, C.sizeof_struct_AdbcDriver) driver.DatabaseInit = (*[0]byte)(C.{{.Prefix}}DatabaseInit) driver.DatabaseNew = (*[0]byte)(C.{{.Prefix}}DatabaseNew) driver.DatabaseRelease = (*[0]byte)(C.{{.Prefix}}DatabaseRelease) @@ -945,6 +1807,41 @@ func {{.Prefix}}DriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcE driver.StatementGetParameterSchema = (*[0]byte)(C.{{.Prefix}}StatementGetParameterSchema) driver.StatementPrepare = (*[0]byte)(C.{{.Prefix}}StatementPrepare) + if version == C.ADBC_VERSION_1_1_0 { + driver.ErrorGetDetailCount = (*[0]byte)(C.{{.Prefix}}ErrorGetDetailCount) + driver.ErrorGetDetail = (*[0]byte)(C.{{.Prefix}}ErrorGetDetail) + driver.ErrorFromArrayStream = (*[0]byte)(C.{{.Prefix}}ErrorFromArrayStream) + + driver.DatabaseGetOption = (*[0]byte)(C.{{.Prefix}}DatabaseGetOption) + driver.DatabaseGetOptionBytes = (*[0]byte)(C.{{.Prefix}}DatabaseGetOptionBytes) + driver.DatabaseGetOptionDouble = (*[0]byte)(C.{{.Prefix}}DatabaseGetOptionDouble) + driver.DatabaseGetOptionInt = (*[0]byte)(C.{{.Prefix}}DatabaseGetOptionInt) + driver.DatabaseSetOptionBytes = (*[0]byte)(C.{{.Prefix}}DatabaseSetOptionBytes) + driver.DatabaseSetOptionDouble = (*[0]byte)(C.{{.Prefix}}DatabaseSetOptionDouble) + driver.DatabaseSetOptionInt = (*[0]byte)(C.{{.Prefix}}DatabaseSetOptionInt) + + driver.ConnectionCancel = (*[0]byte)(C.{{.Prefix}}ConnectionCancel) + driver.ConnectionGetOption = (*[0]byte)(C.{{.Prefix}}ConnectionGetOption) + driver.ConnectionGetOptionBytes = (*[0]byte)(C.{{.Prefix}}ConnectionGetOptionBytes) + driver.ConnectionGetOptionDouble = (*[0]byte)(C.{{.Prefix}}ConnectionGetOptionDouble) + driver.ConnectionGetOptionInt = (*[0]byte)(C.{{.Prefix}}ConnectionGetOptionInt) + driver.ConnectionGetStatistics = (*[0]byte)(C.{{.Prefix}}ConnectionGetStatistics) + driver.ConnectionGetStatisticNames = (*[0]byte)(C.{{.Prefix}}ConnectionGetStatisticNames) + driver.ConnectionSetOptionBytes = (*[0]byte)(C.{{.Prefix}}ConnectionSetOptionBytes) + driver.ConnectionSetOptionDouble = (*[0]byte)(C.{{.Prefix}}ConnectionSetOptionDouble) + driver.ConnectionSetOptionInt = (*[0]byte)(C.{{.Prefix}}ConnectionSetOptionInt) + + driver.StatementCancel = (*[0]byte)(C.{{.Prefix}}StatementCancel) + driver.StatementExecuteSchema = (*[0]byte)(C.{{.Prefix}}StatementExecuteSchema) + driver.StatementGetOption = (*[0]byte)(C.{{.Prefix}}StatementGetOption) + driver.StatementGetOptionBytes = (*[0]byte)(C.{{.Prefix}}StatementGetOptionBytes) + driver.StatementGetOptionDouble = (*[0]byte)(C.{{.Prefix}}StatementGetOptionDouble) + driver.StatementGetOptionInt = (*[0]byte)(C.{{.Prefix}}StatementGetOptionInt) + driver.StatementSetOptionBytes = (*[0]byte)(C.{{.Prefix}}StatementSetOptionBytes) + driver.StatementSetOptionDouble = (*[0]byte)(C.{{.Prefix}}StatementSetOptionDouble) + driver.StatementSetOptionInt = (*[0]byte)(C.{{.Prefix}}StatementSetOptionInt) + } + return C.ADBC_STATUS_OK } diff --git a/go/adbc/pkg/_tmpl/utils.c.tmpl b/go/adbc/pkg/_tmpl/utils.c.tmpl index 38222875fd..195e788cfa 100644 --- a/go/adbc/pkg/_tmpl/utils.c.tmpl +++ b/go/adbc/pkg/_tmpl/utils.c.tmpl @@ -17,7 +17,7 @@ // clang-format off //go:build driverlib -// clang-format on +// clang-format on #include "utils.h" @@ -33,51 +33,142 @@ void {{.Prefix}}_release_error(struct AdbcError* error) { error->release = NULL; } -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return {{.Prefix}}DatabaseNew(database, error); +void {{.Prefix}}ReleaseErrWithDetails(struct AdbcError* error) { + if (!error || error->release != {{.Prefix}}ReleaseErrWithDetails || + !error->private_data) { + return; + } + + struct {{.Prefix}}Error* details = + (struct {{.Prefix}}Error*) error->private_data; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + free(details->keys); + free(details->values); + free(details->lengths); + free(details); + + free(error->message); + error->message = NULL; + error->release = NULL; + error->private_data = NULL; } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return {{.Prefix}}DatabaseSetOption(database, key, value, error); +int {{.Prefix}}ErrorGetDetailCount(const struct AdbcError* error) { + if (!error || error->release != {{.Prefix}}ReleaseErrWithDetails || + !error->private_data) { + return 0; + } + + return ((struct {{.Prefix}}Error*) error->private_data)->count; +} + +struct AdbcErrorDetail {{.Prefix}}ErrorGetDetail(const struct AdbcError* error, + int index) { + if (!error || error->release != {{.Prefix}}ReleaseErrWithDetails || + !error->private_data) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct {{.Prefix}}Error* details = (struct {{.Prefix}}Error*) error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index] + }; +} + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return {{.Prefix}}ErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return {{.Prefix}}ErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return {{.Prefix}}ErrorFromArrayStream(stream, status); +} + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return {{.Prefix}}DatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return {{.Prefix}}DatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return {{.Prefix}}DatabaseGetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return {{.Prefix}}DatabaseGetOptionInt(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return {{.Prefix}}DatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return {{.Prefix}}DatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return {{.Prefix}}DatabaseRelease(database, error); } -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, - struct AdbcError* error) { - return {{.Prefix}}ConnectionNew(connection, error); +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return {{.Prefix}}DatabaseSetOption(database, key, value, error); } -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, - const char* value, struct AdbcError* error) { - return {{.Prefix}}ConnectionSetOption(connection, key, value, error); +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return {{.Prefix}}DatabaseSetOptionBytes(database, key, value, length, error); } -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, - struct AdbcDatabase* database, - struct AdbcError* error) { - return {{.Prefix}}ConnectionInit(connection, database, error); +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return {{.Prefix}}DatabaseSetOptionDouble(database, key, value, error); } -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, - struct AdbcError* error) { - return {{.Prefix}}ConnectionRelease(connection, error); +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return {{.Prefix}}DatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return {{.Prefix}}ConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return {{.Prefix}}ConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); - return {{.Prefix}}ConnectionGetInfo(connection, info_codes, info_codes_length, out, error); + return {{.Prefix}}ConnectionGetInfo(connection, info_codes, info_codes_length, + out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -88,7 +179,46 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return {{.Prefix}}ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_type, column_name, out, error); + table_type, column_name, out, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return {{.Prefix}}ConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetStatisticNames(connection, out, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -108,6 +238,17 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, return {{.Prefix}}ConnectionGetTableTypes(connection, out, error); } +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return {{.Prefix}}ConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return {{.Prefix}}ConnectionNew(connection, error); +} + AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, const uint8_t* serialized_partition, size_t serialized_length, @@ -118,9 +259,9 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, serialized_length, out, error); } -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, - struct AdbcError* error) { - return {{.Prefix}}ConnectionCommit(connection, error); +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return {{.Prefix}}ConnectionRelease(connection, error); } AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, @@ -128,39 +269,32 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return {{.Prefix}}ConnectionRollback(connection, error); } -AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, - struct AdbcStatement* statement, - struct AdbcError* error) { - return {{.Prefix}}StatementNew(connection, statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, - struct AdbcError* error) { - return {{.Prefix}}StatementRelease(statement, error); +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return {{.Prefix}}ConnectionSetOption(connection, key, value, error); } -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, - struct ArrowArrayStream* out, - int64_t* rows_affected, - struct AdbcError* error) { - if (out) memset(out, 0, sizeof(*out)); - return {{.Prefix}}StatementExecuteQuery(statement, out, rows_affected, error); +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return {{.Prefix}}ConnectionSetOptionBytes(connection, key, value, length, error); } -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, - struct AdbcError* error) { - return {{.Prefix}}StatementPrepare(statement, error); +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return {{.Prefix}}ConnectionSetOptionDouble(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, - const char* query, struct AdbcError* error) { - return {{.Prefix}}StatementSetSqlQuery(statement, query, error); +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return {{.Prefix}}ConnectionSetOptionInt(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, - const uint8_t* plan, size_t length, - struct AdbcError* error) { - return {{.Prefix}}StatementSetSubstraitPlan(statement, plan, length, error); +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return {{.Prefix}}StatementCancel(statement, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -175,6 +309,56 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return {{.Prefix}}StatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + if (partitions) memset(partitions, 0, sizeof(*partitions)); + return {{.Prefix}}StatementExecutePartitions(statement, schema, partitions, + rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (out) memset(out, 0, sizeof(*out)); + return {{.Prefix}}StatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + return {{.Prefix}}StatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return {{.Prefix}}StatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return {{.Prefix}}StatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return {{.Prefix}}StatementGetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + return {{.Prefix}}StatementGetOptionInt(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -182,20 +366,54 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, return {{.Prefix}}StatementGetParameterSchema(statement, schema, error); } +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return {{.Prefix}}StatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return {{.Prefix}}StatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return {{.Prefix}}StatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return {{.Prefix}}StatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return {{.Prefix}}StatementSetSubstraitPlan(statement, plan, length, error); +} + AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { return {{.Prefix}}StatementSetOption(statement, key, value, error); } -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, - struct ArrowSchema* schema, - struct AdbcPartitions* partitions, - int64_t* rows_affected, - struct AdbcError* error) { - if (schema) memset(schema, 0, sizeof(*schema)); - if (partitions) memset(partitions, 0, sizeof(*partitions)); - return {{.Prefix}}StatementExecutePartitions(statement, schema, partitions, rows_affected, - error); +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return {{.Prefix}}StatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return {{.Prefix}}StatementSetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + return {{.Prefix}}StatementSetOptionInt(statement, key, value, error); } ADBC_EXPORT @@ -203,6 +421,23 @@ AdbcStatusCode AdbcDriverInit(int version, void* driver, struct AdbcError* error return {{.Prefix}}DriverInit(version, driver, error); } +int {{.Prefix}}ArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +int {{.Prefix}}ArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); + +int {{.Prefix}}ArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return {{.Prefix}}ArrayStreamGetSchema(stream, out); +} + +int {{.Prefix}}ArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return {{.Prefix}}ArrayStreamGetNext(stream, out); +} + #ifdef __cplusplus } #endif diff --git a/go/adbc/pkg/_tmpl/utils.h.tmpl b/go/adbc/pkg/_tmpl/utils.h.tmpl index a8c7d973f6..ce3dba9dbc 100644 --- a/go/adbc/pkg/_tmpl/utils.h.tmpl +++ b/go/adbc/pkg/_tmpl/utils.h.tmpl @@ -24,32 +24,62 @@ #include "../../drivermgr/adbc.h" #include -AdbcStatusCode {{.Prefix}}DatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}DatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); +struct AdbcError* {{.Prefix}}ErrorFromArrayStream(struct ArrowArrayStream*, AdbcStatusCode*); +AdbcStatusCode {{.Prefix}}DatabaseGetOption(struct AdbcDatabase*, const char*, char*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseGetOptionBytes(struct AdbcDatabase*, const char*, uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseGetOptionDouble(struct AdbcDatabase*, const char*, double*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseGetOptionInt(struct AdbcDatabase*, const char*, int64_t*, struct AdbcError*); AdbcStatusCode {{.Prefix}}DatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}DatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode {{.Prefix}}DatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionSetOption(struct AdbcConnection* cnxn, const char* key, const char* val, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionInit(struct AdbcConnection* cnxn, struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionRelease(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionGetInfo(struct AdbcConnection* cnxn, uint32_t* codes, size_t len, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}DatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}DatabaseSetOptionBytes(struct AdbcDatabase*, const char*, const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseSetOptionDouble(struct AdbcDatabase*, const char*, double, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseSetOptionInt(struct AdbcDatabase*, const char*, int64_t, struct AdbcError*); + +AdbcStatusCode {{.Prefix}}ConnectionCancel(struct AdbcConnection*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionCommit(struct AdbcConnection* cnxn, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionGetInfo(struct AdbcConnection* cnxn, const uint32_t* codes, size_t len, struct ArrowArrayStream* out, struct AdbcError* err); AdbcStatusCode {{.Prefix}}ConnectionGetObjects(struct AdbcConnection* cnxn, int depth, const char* catalog, const char* dbSchema, const char* tableName, const char** tableType, const char* columnName, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionGetOption(struct AdbcConnection*, const char*, char*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetOptionBytes(struct AdbcConnection*, const char*, uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetOptionDouble(struct AdbcConnection*, const char*, double*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetOptionInt(struct AdbcConnection*, const char*, int64_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetStatistics(struct AdbcConnection*, const char*, const char*, const char*, char, struct ArrowArrayStream*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetStatisticNames(struct AdbcConnection*, struct ArrowArrayStream*, struct AdbcError*); AdbcStatusCode {{.Prefix}}ConnectionGetTableSchema(struct AdbcConnection* cnxn, const char* catalog, const char* dbSchema, const char* tableName, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode {{.Prefix}}ConnectionGetTableTypes(struct AdbcConnection* cnxn, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionInit(struct AdbcConnection* cnxn, struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); AdbcStatusCode {{.Prefix}}ConnectionReadPartition(struct AdbcConnection* cnxn, const uint8_t* serialized, size_t serializedLen, struct ArrowArrayStream* out, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionCommit(struct AdbcConnection* cnxn, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionRelease(struct AdbcConnection* cnxn, struct AdbcError* err); AdbcStatusCode {{.Prefix}}ConnectionRollback(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementNew(struct AdbcConnection* cnxn, struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementRelease(struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementPrepare(struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementExecuteQuery(struct AdbcStatement* stmt, struct ArrowArrayStream* out, int64_t* affected, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementSetSqlQuery(struct AdbcStatement* stmt, const char* query, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementSetSubstraitPlan(struct AdbcStatement* stmt, const uint8_t* plan, size_t length, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionSetOption(struct AdbcConnection* cnxn, const char* key, const char* val, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionSetOptionBytes(struct AdbcConnection*, const char*, const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionSetOptionDouble(struct AdbcConnection*, const char*, double, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionSetOptionInt(struct AdbcConnection*, const char*, int64_t, struct AdbcError*); + AdbcStatusCode {{.Prefix}}StatementBind(struct AdbcStatement* stmt, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode {{.Prefix}}StatementBindStream(struct AdbcStatement* stmt, struct ArrowArrayStream* stream, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementCancel(struct AdbcStatement*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementExecuteQuery(struct AdbcStatement* stmt, struct ArrowArrayStream* out, int64_t* affected, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementExecuteSchema(struct AdbcStatement*, struct ArrowSchema*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementGetOption(struct AdbcStatement*, const char*, char*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementGetOptionBytes(struct AdbcStatement*, const char*, uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementGetOptionDouble(struct AdbcStatement*, const char*, double*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementGetOptionInt(struct AdbcStatement*, const char*, int64_t*, struct AdbcError*); AdbcStatusCode {{.Prefix}}StatementGetParameterSchema(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementNew(struct AdbcConnection* cnxn, struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementPrepare(struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementRelease(struct AdbcStatement* stmt, struct AdbcError* err); AdbcStatusCode {{.Prefix}}StatementSetOption(struct AdbcStatement* stmt, const char* key, const char* value, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementSetOptionBytes(struct AdbcStatement*, const char*, const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementSetOptionDouble(struct AdbcStatement*, const char*, double, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementSetOptionInt(struct AdbcStatement*, const char*, int64_t, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementSetSqlQuery(struct AdbcStatement* stmt, const char* query, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementSetSubstraitPlan(struct AdbcStatement* stmt, const uint8_t* plan, size_t length, struct AdbcError* err); + AdbcStatusCode {{.Prefix}}DriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void {{.Prefix}}errRelease(struct AdbcError* error) { @@ -57,3 +87,19 @@ static inline void {{.Prefix}}errRelease(struct AdbcError* error) { } void {{.Prefix}}_release_error(struct AdbcError* error); + +struct {{.Prefix}}Error { + char* message; + char** keys; + uint8_t** values; + size_t* lengths; + int count; +}; + +void {{.Prefix}}ReleaseErrWithDetails(struct AdbcError* error); + +int {{.Prefix}}ErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail {{.Prefix}}ErrorGetDetail(const struct AdbcError* error, int index); + +int {{.Prefix}}ArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, struct ArrowSchema* out); +int {{.Prefix}}ArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, struct ArrowArray* out); diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go index a8a606dd0b..925fd8658e 100644 --- a/go/adbc/pkg/flightsql/driver.go +++ b/go/adbc/pkg/flightsql/driver.go @@ -28,11 +28,22 @@ package main // #cgo CXXFLAGS: -std=c++11 -DADBC_EXPORTING // #include "../../drivermgr/adbc.h" // #include "utils.h" +// #include // #include // #include // // typedef const char cchar_t; // typedef const uint8_t cuint8_t; +// typedef const uint32_t cuint32_t; +// typedef const struct AdbcError ConstAdbcError; +// +// int FlightSQLArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// int FlightSQLArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// const char* FlightSQLArrayStreamGetLastError(struct ArrowArrayStream*); +// void FlightSQLArrayStreamRelease(struct ArrowArrayStream*); +// +// int FlightSQLArrayStreamGetSchemaTrampoline(struct ArrowArrayStream*, struct ArrowSchema*); +// int FlightSQLArrayStreamGetNextTrampoline(struct ArrowArrayStream*, struct ArrowArray*); // // void releasePartitions(struct AdbcPartitions* partitions); // @@ -51,6 +62,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc/driver/flightsql" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/cdata" + "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/arrow/memory/mallocator" ) @@ -78,14 +90,63 @@ func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) { err.release = (*[0]byte)(C.FlightSQL_release_error) } +func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { + if err == nil { + return + } + + if err.vendor_code != C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA { + setErr(err, adbcError.Msg) + return + } + + cErrPtr := C.malloc(C.sizeof_struct_FlightSQLError) + cErr := (*C.struct_FlightSQLError)(cErrPtr) + cErr.message = C.CString(adbcError.Msg) + err.message = cErr.message + err.release = (*[0]byte)(C.FlightSQLReleaseErrWithDetails) + err.private_data = cErrPtr + + numDetails := len(adbcError.Details) + if numDetails > 0 { + cErr.keys = (**C.cchar_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cchar_t)(nil))))) + cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) + cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + + keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) + values := fromCArr[*C.cuint8_t](cErr.values, numDetails) + lengths := fromCArr[C.size_t](cErr.lengths, numDetails) + + for i, detail := range adbcError.Details { + keys[i] = C.CString(detail.Key()) + bytes, err := detail.Serialize() + if err != nil { + msg := err.Error() + values[i] = (*C.cuint8_t)(unsafe.Pointer(C.CString(msg))) + lengths[i] = C.size_t(len(msg)) + } else { + values[i] = (*C.cuint8_t)(C.malloc(C.size_t(len(bytes)))) + sink := fromCArr[byte]((*byte)(values[i]), len(bytes)) + copy(sink, bytes) + lengths[i] = C.size_t(len(bytes)) + } + } + } else { + cErr.keys = nil + cErr.values = nil + cErr.lengths = nil + } + cErr.count = C.int(numDetails) +} + func errToAdbcErr(adbcerr *C.struct_AdbcError, err error) adbc.Status { - if adbcerr == nil || err == nil { + if err == nil { return adbc.StatusOK } var adbcError adbc.Error if errors.As(err, &adbcError) { - setErr(adbcerr, adbcError.Msg) + setErrWithDetails(adbcerr, adbcError) return adbcError.Code } @@ -123,6 +184,45 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T { return cgo.Handle((uintptr)(*hptr)).Value().(*T) } +func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { + lenWithTerminator := C.size_t(len(val) + 1) + if lenWithTerminator <= *length { + sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) + copy(sink, val) + sink[lenWithTerminator] = 0 + } + *length = lenWithTerminator + return C.ADBC_STATUS_OK +} + +func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode { + if C.size_t(len(val)) <= *length { + sink := fromCArr[byte]((*byte)(out), int(*length)) + copy(sink, val) + } + *length = C.size_t(len(val)) + return C.ADBC_STATUS_OK +} + +type cancellableContext struct { + ctx context.Context + cancel context.CancelFunc +} + +func (c *cancellableContext) newContext() context.Context { + c.cancelContext() + c.ctx, c.cancel = context.WithCancel(context.Background()) + return c.ctx +} + +func (c *cancellableContext) cancelContext() { + if c.cancel != nil { + c.cancel() + } + c.ctx = nil + c.cancel = nil +} + func checkDBAlloc(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) @@ -152,48 +252,243 @@ func checkDBInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname strin return cdb } +// Custom ArrowArrayStream export to support ADBC error data in ArrowArrayStream + +type cArrayStream struct { + rdr array.RecordReader + // Must be C-allocated + adbcErr *C.struct_AdbcError + status C.AdbcStatusCode +} + +func (cStream *cArrayStream) maybeError() C.int { + err := cStream.rdr.Err() + if err != nil { + if cStream.adbcErr != nil { + C.FlightSQLerrRelease(cStream.adbcErr) + } else { + cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + } + cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) + switch adbc.Status(cStream.status) { + case adbc.StatusUnknown: + return C.EIO + case adbc.StatusNotImplemented: + return C.ENOTSUP + case adbc.StatusNotFound: + return C.ENOENT + case adbc.StatusAlreadyExists: + return C.EEXIST + case adbc.StatusInvalidArgument: + return C.EINVAL + case adbc.StatusInvalidState: + return C.EINVAL + case adbc.StatusInvalidData: + return C.EIO + case adbc.StatusIntegrity: + return C.EIO + case adbc.StatusInternal: + return C.EIO + case adbc.StatusIO: + return C.EIO + case adbc.StatusCancelled: + return C.ECANCELED + case adbc.StatusTimeout: + return C.ETIMEDOUT + case adbc.StatusUnauthenticated: + return C.EACCES + case adbc.StatusUnauthorized: + return C.EACCES + default: + return C.EIO + } + } + return 0 +} + +//export FlightSQLArrayStreamGetLastError +func FlightSQLArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.adbcErr != nil { + return cStream.adbcErr.message + } + return nil +} + +//export FlightSQLArrayStreamGetNext +func FlightSQLArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.rdr.Next() { + cdata.ExportArrowRecordBatch(cStream.rdr.Record(), toCdataArray(array), nil) + return 0 + } + array.release = nil + array.private_data = nil + return cStream.maybeError() +} + +//export FlightSQLArrayStreamGetSchema +func FlightSQLArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + s := cStream.rdr.Schema() + if s == nil { + return cStream.maybeError() + } + cdata.ExportArrowSchema(s, toCdataSchema(schema)) + return 0 +} + +//export FlightSQLArrayStreamRelease +func FlightSQLArrayStreamRelease(stream *C.struct_ArrowArrayStream) { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return + } + h := (*(*cgo.Handle)(stream.private_data)) + + cStream := h.Value().(*cArrayStream) + cStream.rdr.Release() + if cStream.adbcErr != nil { + C.FlightSQLerrRelease(cStream.adbcErr) + C.free(unsafe.Pointer(cStream.adbcErr)) + } + C.free(unsafe.Pointer(stream.private_data)) + stream.private_data = nil + h.Delete() + runtime.GC() +} + +//export FlightSQLErrorFromArrayStream +func FlightSQLErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) *C.struct_AdbcError { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if status != nil { + *status = cStream.status + } + return cStream.adbcErr +} + +func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStream) { + cStream := &cArrayStream{rdr: rdr, status: C.ADBC_STATUS_OK} + stream.get_last_error = (*[0]byte)(C.FlightSQLArrayStreamGetLastError) + stream.get_next = (*[0]byte)(C.FlightSQLArrayStreamGetNextTrampoline) + stream.get_schema = (*[0]byte)(C.FlightSQLArrayStreamGetSchemaTrampoline) + stream.release = (*[0]byte)(C.FlightSQLArrayStreamRelease) + hndl := cgo.NewHandle(cStream) + stream.private_data = createHandle(hndl) + rdr.Retain() +} + type cDatabase struct { opts map[string]string db adbc.Database } -//export FlightSQLDatabaseNew -func FlightSQLDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export FlightSQLDatabaseGetOption +func FlightSQLDatabaseGetOption(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseNew", e) + code = poison(err, "AdbcDatabaseGetOption", e) } }() - if atomic.LoadInt32(&globalPoison) != 0 { - setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") - return C.ADBC_STATUS_INTERNAL + cdb := checkDBInit(db, err, "AdbcDatabaseGetOption") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE } - if db.private_data != nil { - setErr(err, "AdbcDatabaseNew: database already allocated") + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export FlightSQLDatabaseGetOptionBytes +func FlightSQLDatabaseGetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionBytes") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} - hndl := cgo.NewHandle(dbobj) - db.private_data = createHandle(hndl) - return C.ADBC_STATUS_OK + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) } -//export FlightSQLDatabaseSetOption -func FlightSQLDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export FlightSQLDatabaseGetOptionDouble +func FlightSQLDatabaseGetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseSetOption", e) + code = poison(err, "AdbcDatabaseGetOptionDouble", e) } }() - if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionDouble") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - cdb := getFromHandle[cDatabase](db.private_data) - k, v := C.GoString(key), C.GoString(value) - cdb.opts[k] = v + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return C.ADBC_STATUS_OK + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export FlightSQLDatabaseGetOptionInt +func FlightSQLDatabaseGetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export FlightSQLDatabaseInit @@ -222,6 +517,27 @@ func FlightSQLDatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) ( return C.ADBC_STATUS_OK } +//export FlightSQLDatabaseNew +func FlightSQLDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseNew", e) + } + }() + if atomic.LoadInt32(&globalPoison) != 0 { + setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") + return C.ADBC_STATUS_INTERNAL + } + if db.private_data != nil { + setErr(err, "AdbcDatabaseNew: database already allocated") + return C.ADBC_STATUS_INVALID_STATE + } + dbobj := &cDatabase{opts: make(map[string]string)} + hndl := cgo.NewHandle(dbobj) + db.private_data = createHandle(hndl) + return C.ADBC_STATUS_OK +} + //export FlightSQLDatabaseRelease func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -250,7 +566,99 @@ func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError return C.ADBC_STATUS_OK } +//export FlightSQLDatabaseSetOption +func FlightSQLDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOption", e) + } + }() + if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + return C.ADBC_STATUS_INVALID_STATE + } + cdb := getFromHandle[cDatabase](db.private_data) + + k, v := C.GoString(key), C.GoString(value) + if cdb.db != nil { + opts, ok := cdb.db.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(k, v))) + } else { + cdb.opts[k] = v + } + + return C.ADBC_STATUS_OK +} + +//export FlightSQLDatabaseSetOptionBytes +func FlightSQLDatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export FlightSQLDatabaseSetOptionDouble +func FlightSQLDatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionDouble", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionDouble") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export FlightSQLDatabaseSetOptionInt +func FlightSQLDatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) +} + type cConn struct { + cancellableContext + cnxn adbc.Connection initArgs map[string]string } @@ -284,6 +692,102 @@ func checkConnInit(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname return conn } +//export FlightSQLConnectionGetOption +func FlightSQLConnectionGetOption(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOption", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOption") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export FlightSQLConnectionGetOptionBytes +func FlightSQLConnectionGetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export FlightSQLConnectionGetOptionDouble +func FlightSQLConnectionGetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export FlightSQLConnectionGetOptionInt +func FlightSQLConnectionGetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + //export FlightSQLConnectionNew func FlightSQLConnectionNew(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -327,13 +831,75 @@ func FlightSQLConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cch return C.ADBC_STATUS_OK } - opts, ok := conn.cnxn.(adbc.PostInitOptions) + opts, ok := conn.cnxn.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcConnectionSetOption: not supported post-init") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))) +} + +//export FlightSQLConnectionSetOptionBytes +func FlightSQLConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export FlightSQLConnectionSetOptionDouble +func FlightSQLConnectionSetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export FlightSQLConnectionSetOptionInt +func FlightSQLConnectionSetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) if !ok { - setErr(err, "AdbcConnectionSetOption: not supported post-init") + setErr(err, "AdbcConnectionSetOptionInt: options are not supported") return C.ADBC_STATUS_NOT_IMPLEMENTED } - rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val))) - return C.AdbcStatusCode(rawCode) + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export FlightSQLConnectionInit @@ -396,8 +962,9 @@ func FlightSQLConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_Adb conn := h.Value().(*cConn) defer func() { + conn.cancelContext() conn.cnxn = nil - C.free(unsafe.Pointer(cnxn.private_data)) + C.free(cnxn.private_data) cnxn.private_data = nil h.Delete() // manually trigger GC for two reasons: @@ -434,26 +1001,19 @@ func toCdataArray(ptr *C.struct_ArrowArray) *cdata.CArrowArray { return (*cdata.CArrowArray)(unsafe.Pointer(ptr)) } -//export FlightSQLConnectionGetInfo -func FlightSQLConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.uint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export FlightSQLConnectionCancel +func FlightSQLConnectionCancel(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcConnectionGetInfo", e) + code = poison(err, "AdbcConnectionCancel", e) } }() - conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + conn := checkConnInit(cnxn, err, "AdbcConnectionCancel") if conn == nil { return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) - rdr, e := conn.cnxn.GetInfo(context.Background(), infoCodes) - if e != nil { - return C.AdbcStatusCode(errToAdbcErr(err, e)) - } - - defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + conn.cancelContext() return C.ADBC_STATUS_OK } @@ -481,6 +1041,29 @@ func toStrSlice(in **C.cchar_t) []string { return out } +//export FlightSQLConnectionGetInfo +func FlightSQLConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetInfo", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + //export FlightSQLConnectionGetObjects func FlightSQLConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, catalog, dbSchema, tableName *C.cchar_t, tableType **C.cchar_t, columnName *C.cchar_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { @@ -494,12 +1077,67 @@ func FlightSQLConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, c return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetObjects(context.Background(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + rdr, e := conn.cnxn.GetObjects(conn.newContext(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export FlightSQLConnectionGetStatistics +func FlightSQLConnectionGetStatistics(cnxn *C.struct_AdbcConnection, catalog, dbSchema, tableName *C.cchar_t, approximate C.char, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatistics(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), int(approximate) != 0) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export FlightSQLConnectionGetStatisticNames +func FlightSQLConnectionGetStatisticNames(cnxn *C.struct_AdbcConnection, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatisticNames(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -515,7 +1153,7 @@ func FlightSQLConnectionGetTableSchema(cnxn *C.struct_AdbcConnection, catalog, d return C.ADBC_STATUS_INVALID_STATE } - sc, e := conn.cnxn.GetTableSchema(context.Background(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) + sc, e := conn.cnxn.GetTableSchema(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -535,12 +1173,12 @@ func FlightSQLConnectionGetTableTypes(cnxn *C.struct_AdbcConnection, out *C.stru return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetTableTypes(context.Background()) + rdr, e := conn.cnxn.GetTableTypes(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -556,12 +1194,12 @@ func FlightSQLConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialized return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(context.Background(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -577,7 +1215,7 @@ func FlightSQLConnectionCommit(cnxn *C.struct_AdbcConnection, err *C.struct_Adbc return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(conn.newContext()))) } //export FlightSQLConnectionRollback @@ -592,25 +1230,137 @@ func FlightSQLConnectionRollback(cnxn *C.struct_AdbcConnection, err *C.struct_Ad return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(conn.newContext()))) +} + +type cStmt struct { + cancellableContext + + stmt adbc.Statement } -func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) adbc.Statement { +func checkStmtAlloc(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) - return nil + return false } if stmt == nil { setErr(err, "%s: statement not allocated", fname) - return nil + return false } - if stmt.private_data == nil { - setErr(err, "%s: statement not initialized", fname) + setErr(err, "%s: statement not allocated", fname) + return false + } + return true +} + +func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) *cStmt { + if !checkStmtAlloc(stmt, err, fname) { + return nil + } + cStmt := getFromHandle[cStmt](stmt.private_data) + if cStmt.stmt == nil { + setErr(err, "%s: statement not allocated", fname) return nil } + return cStmt +} + +//export FlightSQLStatementGetOption +func FlightSQLStatementGetOption(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOption", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOption") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export FlightSQLStatementGetOptionBytes +func FlightSQLStatementGetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export FlightSQLStatementGetOptionDouble +func FlightSQLStatementGetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export FlightSQLStatementGetOptionInt +func FlightSQLStatementGetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return (*(*cgo.Handle)(stmt.private_data)).Value().(adbc.Statement) + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export FlightSQLStatementNew @@ -639,8 +1389,8 @@ func FlightSQLStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSta return C.AdbcStatusCode(errToAdbcErr(err, e)) } - h := cgo.NewHandle(st) - stmt.private_data = createHandle(h) + hndl := cgo.NewHandle(&cStmt{stmt: st}) + stmt.private_data = createHandle(hndl) return C.ADBC_STATUS_OK } @@ -655,31 +1405,46 @@ func FlightSQLStatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_AdbcE setErr(err, "AdbcStatementRelease: Go panicked, driver is in unknown state") return C.ADBC_STATUS_INTERNAL } - if stmt == nil { - setErr(err, "AdbcStatementRelease: statement not allocated") + if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } + h := (*(*cgo.Handle)(stmt.private_data)) - if stmt.private_data == nil { - setErr(err, "AdbcStatementRelease: statement not initialized") - return C.ADBC_STATUS_INVALID_STATE + st := h.Value().(*cStmt) + defer func() { + st.cancelContext() + st.stmt = nil + C.free(stmt.private_data) + stmt.private_data = nil + h.Delete() + // manually trigger GC for two reasons: + // 1. ASAN expects the release callback to be called before + // the process ends, but GC is not deterministic. So by manually + // triggering the GC we ensure the release callback gets called. + // 2. Creates deterministic GC behavior by all Release functions + // triggering a garbage collection + runtime.GC() + }() + if st.stmt == nil { + return C.ADBC_STATUS_OK } + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Close())) +} - h := (*(*cgo.Handle)(stmt.private_data)) - st := h.Value().(adbc.Statement) - C.free(stmt.private_data) - stmt.private_data = nil +//export FlightSQLStatementCancel +func FlightSQLStatementCancel(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementCancel", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementCancel") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } - e := st.Close() - h.Delete() - // manually trigger GC for two reasons: - // 1. ASAN expects the release callback to be called before - // the process ends, but GC is not deterministic. So by manually - // triggering the GC we ensure the release callback gets called. - // 2. Creates deterministic GC behavior by all Release functions - // triggering a garbage collection - runtime.GC() - return C.AdbcStatusCode(errToAdbcErr(err, e)) + st.cancelContext() + return C.ADBC_STATUS_OK } //export FlightSQLStatementPrepare @@ -694,7 +1459,7 @@ func FlightSQLStatementPrepare(stmt *C.struct_AdbcStatement, err *C.struct_AdbcE return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.Prepare(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Prepare(st.newContext()))) } //export FlightSQLStatementExecuteQuery @@ -710,7 +1475,7 @@ func FlightSQLStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ } if out == nil { - n, e := st.ExecuteUpdate(context.Background()) + n, e := st.stmt.ExecuteUpdate(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -719,7 +1484,7 @@ func FlightSQLStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ *affected = C.int64_t(n) } } else { - rdr, n, e := st.ExecuteQuery(context.Background()) + rdr, n, e := st.stmt.ExecuteQuery(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -729,8 +1494,35 @@ func FlightSQLStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) + } + return C.ADBC_STATUS_OK +} + +//export FlightSQLStatementExecuteSchema +func FlightSQLStatementExecuteSchema(stmt *C.struct_AdbcStatement, schema *C.struct_ArrowSchema, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementExecuteQuery", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementExecuteQuery") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + es, ok := st.stmt.(adbc.StatementExecuteSchema) + if !ok { + setErr(err, "AdbcStatementExecuteSchema: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + sc, e := es.ExecuteSchema(st.newContext()) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) } + + cdata.ExportArrowSchema(sc, toCdataSchema(schema)) return C.ADBC_STATUS_OK } @@ -746,7 +1538,7 @@ func FlightSQLStatementSetSqlQuery(stmt *C.struct_AdbcStatement, query *C.cchar_ return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSqlQuery(C.GoString(query)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSqlQuery(C.GoString(query)))) } //export FlightSQLStatementSetSubstraitPlan @@ -761,7 +1553,7 @@ func FlightSQLStatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C.cu return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) } //export FlightSQLStatementBind @@ -784,7 +1576,7 @@ func FlightSQLStatementBind(stmt *C.struct_AdbcStatement, values *C.struct_Arrow } defer rec.Release() - return C.AdbcStatusCode(errToAdbcErr(err, st.Bind(context.Background(), rec))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Bind(st.newContext(), rec))) } //export FlightSQLStatementBindStream @@ -803,7 +1595,7 @@ func FlightSQLStatementBindStream(stmt *C.struct_AdbcStatement, stream *C.struct if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } - return C.AdbcStatusCode(errToAdbcErr(err, st.BindStream(context.Background(), rdr.(array.RecordReader)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.BindStream(st.newContext(), rdr.(array.RecordReader)))) } //export FlightSQLStatementGetParameterSchema @@ -818,7 +1610,7 @@ func FlightSQLStatementGetParameterSchema(stmt *C.struct_AdbcStatement, schema * return C.ADBC_STATUS_INVALID_STATE } - sc, e := st.GetParameterSchema() + sc, e := st.stmt.GetParameterSchema() if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -839,7 +1631,70 @@ func FlightSQLStatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.cch return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetOption(C.GoString(key), C.GoString(value)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetOption(C.GoString(key), C.GoString(value)))) +} + +//export FlightSQLStatementSetOptionBytes +func FlightSQLStatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export FlightSQLStatementSetOptionDouble +func FlightSQLStatementSetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export FlightSQLStatementSetOptionInt +func FlightSQLStatementSetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export releasePartitions @@ -868,7 +1723,7 @@ func FlightSQLStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C return C.ADBC_STATUS_INVALID_STATE } - sc, part, n, e := st.ExecutePartitions(context.Background()) + sc, part, n, e := st.stmt.ExecutePartitions(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -913,13 +1768,20 @@ func FlightSQLStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C //export FlightSQLDriverInit func FlightSQLDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcError) C.AdbcStatusCode { - if version != C.ADBC_VERSION_1_0_0 { - setErr(err, "Only version %d supported, got %d", int(C.ADBC_VERSION_1_0_0), int(version)) + driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) + + switch version { + case C.ADBC_VERSION_1_0_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) + memory.Set(sink, 0) + case C.ADBC_VERSION_1_1_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) + memory.Set(sink, 0) + default: + setErr(err, "Only version 1.0.0/1.1.0 supported, got %d", int(version)) return C.ADBC_STATUS_NOT_IMPLEMENTED } - driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) - C.memset(unsafe.Pointer(driver), 0, C.sizeof_struct_AdbcDriver) driver.DatabaseInit = (*[0]byte)(C.FlightSQLDatabaseInit) driver.DatabaseNew = (*[0]byte)(C.FlightSQLDatabaseNew) driver.DatabaseRelease = (*[0]byte)(C.FlightSQLDatabaseRelease) @@ -949,6 +1811,41 @@ func FlightSQLDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcErr driver.StatementGetParameterSchema = (*[0]byte)(C.FlightSQLStatementGetParameterSchema) driver.StatementPrepare = (*[0]byte)(C.FlightSQLStatementPrepare) + if version == C.ADBC_VERSION_1_1_0 { + driver.ErrorGetDetailCount = (*[0]byte)(C.FlightSQLErrorGetDetailCount) + driver.ErrorGetDetail = (*[0]byte)(C.FlightSQLErrorGetDetail) + driver.ErrorFromArrayStream = (*[0]byte)(C.FlightSQLErrorFromArrayStream) + + driver.DatabaseGetOption = (*[0]byte)(C.FlightSQLDatabaseGetOption) + driver.DatabaseGetOptionBytes = (*[0]byte)(C.FlightSQLDatabaseGetOptionBytes) + driver.DatabaseGetOptionDouble = (*[0]byte)(C.FlightSQLDatabaseGetOptionDouble) + driver.DatabaseGetOptionInt = (*[0]byte)(C.FlightSQLDatabaseGetOptionInt) + driver.DatabaseSetOptionBytes = (*[0]byte)(C.FlightSQLDatabaseSetOptionBytes) + driver.DatabaseSetOptionDouble = (*[0]byte)(C.FlightSQLDatabaseSetOptionDouble) + driver.DatabaseSetOptionInt = (*[0]byte)(C.FlightSQLDatabaseSetOptionInt) + + driver.ConnectionCancel = (*[0]byte)(C.FlightSQLConnectionCancel) + driver.ConnectionGetOption = (*[0]byte)(C.FlightSQLConnectionGetOption) + driver.ConnectionGetOptionBytes = (*[0]byte)(C.FlightSQLConnectionGetOptionBytes) + driver.ConnectionGetOptionDouble = (*[0]byte)(C.FlightSQLConnectionGetOptionDouble) + driver.ConnectionGetOptionInt = (*[0]byte)(C.FlightSQLConnectionGetOptionInt) + driver.ConnectionGetStatistics = (*[0]byte)(C.FlightSQLConnectionGetStatistics) + driver.ConnectionGetStatisticNames = (*[0]byte)(C.FlightSQLConnectionGetStatisticNames) + driver.ConnectionSetOptionBytes = (*[0]byte)(C.FlightSQLConnectionSetOptionBytes) + driver.ConnectionSetOptionDouble = (*[0]byte)(C.FlightSQLConnectionSetOptionDouble) + driver.ConnectionSetOptionInt = (*[0]byte)(C.FlightSQLConnectionSetOptionInt) + + driver.StatementCancel = (*[0]byte)(C.FlightSQLStatementCancel) + driver.StatementExecuteSchema = (*[0]byte)(C.FlightSQLStatementExecuteSchema) + driver.StatementGetOption = (*[0]byte)(C.FlightSQLStatementGetOption) + driver.StatementGetOptionBytes = (*[0]byte)(C.FlightSQLStatementGetOptionBytes) + driver.StatementGetOptionDouble = (*[0]byte)(C.FlightSQLStatementGetOptionDouble) + driver.StatementGetOptionInt = (*[0]byte)(C.FlightSQLStatementGetOptionInt) + driver.StatementSetOptionBytes = (*[0]byte)(C.FlightSQLStatementSetOptionBytes) + driver.StatementSetOptionDouble = (*[0]byte)(C.FlightSQLStatementSetOptionDouble) + driver.StatementSetOptionInt = (*[0]byte)(C.FlightSQLStatementSetOptionInt) + } + return C.ADBC_STATUS_OK } diff --git a/go/adbc/pkg/flightsql/utils.c b/go/adbc/pkg/flightsql/utils.c index 41777a98c4..95920aa498 100644 --- a/go/adbc/pkg/flightsql/utils.c +++ b/go/adbc/pkg/flightsql/utils.c @@ -35,52 +35,142 @@ void FlightSQL_release_error(struct AdbcError* error) { error->release = NULL; } -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return FlightSQLDatabaseNew(database, error); +void FlightSQLReleaseErrWithDetails(struct AdbcError* error) { + if (!error || error->release != FlightSQLReleaseErrWithDetails || + !error->private_data) { + return; + } + + struct FlightSQLError* details = + (struct FlightSQLError*) error->private_data; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + free(details->keys); + free(details->values); + free(details->lengths); + free(details); + + free(error->message); + error->message = NULL; + error->release = NULL; + error->private_data = NULL; } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return FlightSQLDatabaseSetOption(database, key, value, error); +int FlightSQLErrorGetDetailCount(const struct AdbcError* error) { + if (!error || error->release != FlightSQLReleaseErrWithDetails || + !error->private_data) { + return 0; + } + + return ((struct FlightSQLError*) error->private_data)->count; +} + +struct AdbcErrorDetail FlightSQLErrorGetDetail(const struct AdbcError* error, + int index) { + if (!error || error->release != FlightSQLReleaseErrWithDetails || + !error->private_data) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct FlightSQLError* details = (struct FlightSQLError*) error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index] + }; +} + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return FlightSQLErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return FlightSQLErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return FlightSQLErrorFromArrayStream(stream, status); +} + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return FlightSQLDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return FlightSQLDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return FlightSQLDatabaseGetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return FlightSQLDatabaseGetOptionInt(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return FlightSQLDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return FlightSQLDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return FlightSQLDatabaseRelease(database, error); } -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, - struct AdbcError* error) { - return FlightSQLConnectionNew(connection, error); +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return FlightSQLDatabaseSetOption(database, key, value, error); } -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, - const char* value, struct AdbcError* error) { - return FlightSQLConnectionSetOption(connection, key, value, error); +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return FlightSQLDatabaseSetOptionBytes(database, key, value, length, error); } -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, - struct AdbcDatabase* database, - struct AdbcError* error) { - return FlightSQLConnectionInit(connection, database, error); +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return FlightSQLDatabaseSetOptionDouble(database, key, value, error); } -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, - struct AdbcError* error) { - return FlightSQLConnectionRelease(connection, error); +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return FlightSQLDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return FlightSQLConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return FlightSQLConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); - return FlightSQLConnectionGetInfo(connection, info_codes, info_codes_length, out, - error); + return FlightSQLConnectionGetInfo(connection, info_codes, info_codes_length, + out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -91,7 +181,46 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return FlightSQLConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_type, column_name, out, error); + table_type, column_name, out, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return FlightSQLConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return FlightSQLConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return FlightSQLConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return FlightSQLConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return FlightSQLConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return FlightSQLConnectionGetStatisticNames(connection, out, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -101,7 +230,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, struct AdbcError* error) { if (schema) memset(schema, 0, sizeof(*schema)); return FlightSQLConnectionGetTableSchema(connection, catalog, db_schema, table_name, - schema, error); + schema, error); } AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, @@ -111,6 +240,17 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, return FlightSQLConnectionGetTableTypes(connection, out, error); } +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return FlightSQLConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return FlightSQLConnectionNew(connection, error); +} + AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, const uint8_t* serialized_partition, size_t serialized_length, @@ -118,12 +258,12 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return FlightSQLConnectionReadPartition(connection, serialized_partition, - serialized_length, out, error); + serialized_length, out, error); } -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, - struct AdbcError* error) { - return FlightSQLConnectionCommit(connection, error); +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return FlightSQLConnectionRelease(connection, error); } AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, @@ -131,39 +271,32 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return FlightSQLConnectionRollback(connection, error); } -AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, - struct AdbcStatement* statement, - struct AdbcError* error) { - return FlightSQLStatementNew(connection, statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, - struct AdbcError* error) { - return FlightSQLStatementRelease(statement, error); +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return FlightSQLConnectionSetOption(connection, key, value, error); } -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, - struct ArrowArrayStream* out, - int64_t* rows_affected, - struct AdbcError* error) { - if (out) memset(out, 0, sizeof(*out)); - return FlightSQLStatementExecuteQuery(statement, out, rows_affected, error); +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return FlightSQLConnectionSetOptionBytes(connection, key, value, length, error); } -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, - struct AdbcError* error) { - return FlightSQLStatementPrepare(statement, error); +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return FlightSQLConnectionSetOptionDouble(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, - const char* query, struct AdbcError* error) { - return FlightSQLStatementSetSqlQuery(statement, query, error); +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return FlightSQLConnectionSetOptionInt(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, - const uint8_t* plan, size_t length, - struct AdbcError* error) { - return FlightSQLStatementSetSubstraitPlan(statement, plan, length, error); +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return FlightSQLStatementCancel(statement, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -178,6 +311,56 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return FlightSQLStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + if (partitions) memset(partitions, 0, sizeof(*partitions)); + return FlightSQLStatementExecutePartitions(statement, schema, partitions, + rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (out) memset(out, 0, sizeof(*out)); + return FlightSQLStatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + return FlightSQLStatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return FlightSQLStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return FlightSQLStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return FlightSQLStatementGetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + return FlightSQLStatementGetOptionInt(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -185,20 +368,54 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, return FlightSQLStatementGetParameterSchema(statement, schema, error); } +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return FlightSQLStatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return FlightSQLStatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return FlightSQLStatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return FlightSQLStatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return FlightSQLStatementSetSubstraitPlan(statement, plan, length, error); +} + AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { return FlightSQLStatementSetOption(statement, key, value, error); } -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, - struct ArrowSchema* schema, - struct AdbcPartitions* partitions, - int64_t* rows_affected, - struct AdbcError* error) { - if (schema) memset(schema, 0, sizeof(*schema)); - if (partitions) memset(partitions, 0, sizeof(*partitions)); - return FlightSQLStatementExecutePartitions(statement, schema, partitions, rows_affected, - error); +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return FlightSQLStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return FlightSQLStatementSetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + return FlightSQLStatementSetOptionInt(statement, key, value, error); } ADBC_EXPORT @@ -206,6 +423,23 @@ AdbcStatusCode AdbcDriverInit(int version, void* driver, struct AdbcError* error return FlightSQLDriverInit(version, driver, error); } +int FlightSQLArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +int FlightSQLArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); + +int FlightSQLArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return FlightSQLArrayStreamGetSchema(stream, out); +} + +int FlightSQLArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return FlightSQLArrayStreamGetNext(stream, out); +} + #ifdef __cplusplus } #endif diff --git a/go/adbc/pkg/flightsql/utils.h b/go/adbc/pkg/flightsql/utils.h index 51a67f240a..e3b22fb737 100644 --- a/go/adbc/pkg/flightsql/utils.h +++ b/go/adbc/pkg/flightsql/utils.h @@ -26,72 +26,151 @@ #include #include "../../drivermgr/adbc.h" +struct AdbcError* FlightSQLErrorFromArrayStream(struct ArrowArrayStream*, + AdbcStatusCode*); +AdbcStatusCode FlightSQLDatabaseGetOption(struct AdbcDatabase*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseGetOptionBytes(struct AdbcDatabase*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseGetOptionDouble(struct AdbcDatabase*, const char*, + double*, struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseGetOptionInt(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode FlightSQLDatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode FlightSQLDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode FlightSQLDatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); -AdbcStatusCode FlightSQLDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode FlightSQLDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionSetOption(struct AdbcConnection* cnxn, const char* key, - const char* val, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionInit(struct AdbcConnection* cnxn, - struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionRelease(struct AdbcConnection* cnxn, - struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionGetInfo(struct AdbcConnection* cnxn, uint32_t* codes, - size_t len, struct ArrowArrayStream* out, +AdbcStatusCode FlightSQLDatabaseSetOptionBytes(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseSetOptionDouble(struct AdbcDatabase*, const char*, double, + struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseSetOptionInt(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + +AdbcStatusCode FlightSQLConnectionCancel(struct AdbcConnection*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionCommit(struct AdbcConnection* cnxn, + struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionGetInfo(struct AdbcConnection* cnxn, + const uint32_t* codes, size_t len, + struct ArrowArrayStream* out, struct AdbcError* err); AdbcStatusCode FlightSQLConnectionGetObjects( struct AdbcConnection* cnxn, int depth, const char* catalog, const char* dbSchema, const char* tableName, const char** tableType, const char* columnName, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionGetOption(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetOptionBytes(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetOptionDouble(struct AdbcConnection*, const char*, + double*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetOptionInt(struct AdbcConnection*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetStatistics(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, + struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); AdbcStatusCode FlightSQLConnectionGetTableSchema( struct AdbcConnection* cnxn, const char* catalog, const char* dbSchema, const char* tableName, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode FlightSQLConnectionGetTableTypes(struct AdbcConnection* cnxn, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionInit(struct AdbcConnection* cnxn, + struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); AdbcStatusCode FlightSQLConnectionReadPartition(struct AdbcConnection* cnxn, const uint8_t* serialized, size_t serializedLen, struct ArrowArrayStream* out, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionCommit(struct AdbcConnection* cnxn, - struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionRelease(struct AdbcConnection* cnxn, + struct AdbcError* err); AdbcStatusCode FlightSQLConnectionRollback(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementNew(struct AdbcConnection* cnxn, - struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementRelease(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode FlightSQLStatementPrepare(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode FlightSQLStatementExecuteQuery(struct AdbcStatement* stmt, - struct ArrowArrayStream* out, - int64_t* affected, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementSetSqlQuery(struct AdbcStatement* stmt, - const char* query, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementSetSubstraitPlan(struct AdbcStatement* stmt, - const uint8_t* plan, size_t length, - struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionSetOption(struct AdbcConnection* cnxn, const char* key, + const char* val, struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode FlightSQLConnectionSetOptionDouble(struct AdbcConnection*, const char*, + double, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionSetOptionInt(struct AdbcConnection*, const char*, + int64_t, struct AdbcError*); + AdbcStatusCode FlightSQLStatementBind(struct AdbcStatement* stmt, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode FlightSQLStatementBindStream(struct AdbcStatement* stmt, struct ArrowArrayStream* stream, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementGetParameterSchema(struct AdbcStatement* stmt, - struct ArrowSchema* schema, - struct AdbcError* err); -AdbcStatusCode FlightSQLStatementSetOption(struct AdbcStatement* stmt, const char* key, - const char* value, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementCancel(struct AdbcStatement*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementExecuteQuery(struct AdbcStatement* stmt, + struct ArrowArrayStream* out, + int64_t* affected, struct AdbcError* err); AdbcStatusCode FlightSQLStatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementExecuteSchema(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetOption(struct AdbcStatement*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetOptionBytes(struct AdbcStatement*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetOptionDouble(struct AdbcStatement*, const char*, + double*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetOptionInt(struct AdbcStatement*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetParameterSchema(struct AdbcStatement* stmt, + struct ArrowSchema* schema, + struct AdbcError* err); +AdbcStatusCode FlightSQLStatementNew(struct AdbcConnection* cnxn, + struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementPrepare(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode FlightSQLStatementRelease(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode FlightSQLStatementSetOption(struct AdbcStatement* stmt, const char* key, + const char* value, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementSetOptionBytes(struct AdbcStatement*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode FlightSQLStatementSetOptionDouble(struct AdbcStatement*, const char*, + double, struct AdbcError*); +AdbcStatusCode FlightSQLStatementSetOptionInt(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); +AdbcStatusCode FlightSQLStatementSetSqlQuery(struct AdbcStatement* stmt, + const char* query, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementSetSubstraitPlan(struct AdbcStatement* stmt, + const uint8_t* plan, size_t length, + struct AdbcError* err); + AdbcStatusCode FlightSQLDriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void FlightSQLerrRelease(struct AdbcError* error) { error->release(error); } void FlightSQL_release_error(struct AdbcError* error); + +struct FlightSQLError { + char* message; + char** keys; + uint8_t** values; + size_t* lengths; + int count; +}; + +void FlightSQLReleaseErrWithDetails(struct AdbcError* error); + +int FlightSQLErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail FlightSQLErrorGetDetail(const struct AdbcError* error, int index); + +int FlightSQLArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out); +int FlightSQLArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out); diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go index 73c34eae4f..d1c143a762 100644 --- a/go/adbc/pkg/panicdummy/driver.go +++ b/go/adbc/pkg/panicdummy/driver.go @@ -28,11 +28,22 @@ package main // #cgo CXXFLAGS: -std=c++11 -DADBC_EXPORTING // #include "../../drivermgr/adbc.h" // #include "utils.h" +// #include // #include // #include // // typedef const char cchar_t; // typedef const uint8_t cuint8_t; +// typedef const uint32_t cuint32_t; +// typedef const struct AdbcError ConstAdbcError; +// +// int PanicDummyArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// int PanicDummyArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// const char* PanicDummyArrayStreamGetLastError(struct ArrowArrayStream*); +// void PanicDummyArrayStreamRelease(struct ArrowArrayStream*); +// +// int PanicDummyArrayStreamGetSchemaTrampoline(struct ArrowArrayStream*, struct ArrowSchema*); +// int PanicDummyArrayStreamGetNextTrampoline(struct ArrowArrayStream*, struct ArrowArray*); // // void releasePartitions(struct AdbcPartitions* partitions); // @@ -51,6 +62,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc/driver/panicdummy" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/cdata" + "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/arrow/memory/mallocator" ) @@ -78,14 +90,63 @@ func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) { err.release = (*[0]byte)(C.PanicDummy_release_error) } +func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { + if err == nil { + return + } + + if err.vendor_code != C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA { + setErr(err, adbcError.Msg) + return + } + + cErrPtr := C.malloc(C.sizeof_struct_PanicDummyError) + cErr := (*C.struct_PanicDummyError)(cErrPtr) + cErr.message = C.CString(adbcError.Msg) + err.message = cErr.message + err.release = (*[0]byte)(C.PanicDummyReleaseErrWithDetails) + err.private_data = cErrPtr + + numDetails := len(adbcError.Details) + if numDetails > 0 { + cErr.keys = (**C.cchar_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cchar_t)(nil))))) + cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) + cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + + keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) + values := fromCArr[*C.cuint8_t](cErr.values, numDetails) + lengths := fromCArr[C.size_t](cErr.lengths, numDetails) + + for i, detail := range adbcError.Details { + keys[i] = C.CString(detail.Key()) + bytes, err := detail.Serialize() + if err != nil { + msg := err.Error() + values[i] = (*C.cuint8_t)(unsafe.Pointer(C.CString(msg))) + lengths[i] = C.size_t(len(msg)) + } else { + values[i] = (*C.cuint8_t)(C.malloc(C.size_t(len(bytes)))) + sink := fromCArr[byte]((*byte)(values[i]), len(bytes)) + copy(sink, bytes) + lengths[i] = C.size_t(len(bytes)) + } + } + } else { + cErr.keys = nil + cErr.values = nil + cErr.lengths = nil + } + cErr.count = C.int(numDetails) +} + func errToAdbcErr(adbcerr *C.struct_AdbcError, err error) adbc.Status { - if adbcerr == nil || err == nil { + if err == nil { return adbc.StatusOK } var adbcError adbc.Error if errors.As(err, &adbcError) { - setErr(adbcerr, adbcError.Msg) + setErrWithDetails(adbcerr, adbcError) return adbcError.Code } @@ -123,6 +184,45 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T { return cgo.Handle((uintptr)(*hptr)).Value().(*T) } +func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { + lenWithTerminator := C.size_t(len(val) + 1) + if lenWithTerminator <= *length { + sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) + copy(sink, val) + sink[lenWithTerminator] = 0 + } + *length = lenWithTerminator + return C.ADBC_STATUS_OK +} + +func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode { + if C.size_t(len(val)) <= *length { + sink := fromCArr[byte]((*byte)(out), int(*length)) + copy(sink, val) + } + *length = C.size_t(len(val)) + return C.ADBC_STATUS_OK +} + +type cancellableContext struct { + ctx context.Context + cancel context.CancelFunc +} + +func (c *cancellableContext) newContext() context.Context { + c.cancelContext() + c.ctx, c.cancel = context.WithCancel(context.Background()) + return c.ctx +} + +func (c *cancellableContext) cancelContext() { + if c.cancel != nil { + c.cancel() + } + c.ctx = nil + c.cancel = nil +} + func checkDBAlloc(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) @@ -152,48 +252,243 @@ func checkDBInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname strin return cdb } +// Custom ArrowArrayStream export to support ADBC error data in ArrowArrayStream + +type cArrayStream struct { + rdr array.RecordReader + // Must be C-allocated + adbcErr *C.struct_AdbcError + status C.AdbcStatusCode +} + +func (cStream *cArrayStream) maybeError() C.int { + err := cStream.rdr.Err() + if err != nil { + if cStream.adbcErr != nil { + C.PanicDummyerrRelease(cStream.adbcErr) + } else { + cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + } + cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) + switch adbc.Status(cStream.status) { + case adbc.StatusUnknown: + return C.EIO + case adbc.StatusNotImplemented: + return C.ENOTSUP + case adbc.StatusNotFound: + return C.ENOENT + case adbc.StatusAlreadyExists: + return C.EEXIST + case adbc.StatusInvalidArgument: + return C.EINVAL + case adbc.StatusInvalidState: + return C.EINVAL + case adbc.StatusInvalidData: + return C.EIO + case adbc.StatusIntegrity: + return C.EIO + case adbc.StatusInternal: + return C.EIO + case adbc.StatusIO: + return C.EIO + case adbc.StatusCancelled: + return C.ECANCELED + case adbc.StatusTimeout: + return C.ETIMEDOUT + case adbc.StatusUnauthenticated: + return C.EACCES + case adbc.StatusUnauthorized: + return C.EACCES + default: + return C.EIO + } + } + return 0 +} + +//export PanicDummyArrayStreamGetLastError +func PanicDummyArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.adbcErr != nil { + return cStream.adbcErr.message + } + return nil +} + +//export PanicDummyArrayStreamGetNext +func PanicDummyArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.rdr.Next() { + cdata.ExportArrowRecordBatch(cStream.rdr.Record(), toCdataArray(array), nil) + return 0 + } + array.release = nil + array.private_data = nil + return cStream.maybeError() +} + +//export PanicDummyArrayStreamGetSchema +func PanicDummyArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + s := cStream.rdr.Schema() + if s == nil { + return cStream.maybeError() + } + cdata.ExportArrowSchema(s, toCdataSchema(schema)) + return 0 +} + +//export PanicDummyArrayStreamRelease +func PanicDummyArrayStreamRelease(stream *C.struct_ArrowArrayStream) { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return + } + h := (*(*cgo.Handle)(stream.private_data)) + + cStream := h.Value().(*cArrayStream) + cStream.rdr.Release() + if cStream.adbcErr != nil { + C.PanicDummyerrRelease(cStream.adbcErr) + C.free(unsafe.Pointer(cStream.adbcErr)) + } + C.free(unsafe.Pointer(stream.private_data)) + stream.private_data = nil + h.Delete() + runtime.GC() +} + +//export PanicDummyErrorFromArrayStream +func PanicDummyErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) *C.struct_AdbcError { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if status != nil { + *status = cStream.status + } + return cStream.adbcErr +} + +func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStream) { + cStream := &cArrayStream{rdr: rdr, status: C.ADBC_STATUS_OK} + stream.get_last_error = (*[0]byte)(C.PanicDummyArrayStreamGetLastError) + stream.get_next = (*[0]byte)(C.PanicDummyArrayStreamGetNextTrampoline) + stream.get_schema = (*[0]byte)(C.PanicDummyArrayStreamGetSchemaTrampoline) + stream.release = (*[0]byte)(C.PanicDummyArrayStreamRelease) + hndl := cgo.NewHandle(cStream) + stream.private_data = createHandle(hndl) + rdr.Retain() +} + type cDatabase struct { opts map[string]string db adbc.Database } -//export PanicDummyDatabaseNew -func PanicDummyDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export PanicDummyDatabaseGetOption +func PanicDummyDatabaseGetOption(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseNew", e) + code = poison(err, "AdbcDatabaseGetOption", e) } }() - if atomic.LoadInt32(&globalPoison) != 0 { - setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") - return C.ADBC_STATUS_INTERNAL + cdb := checkDBInit(db, err, "AdbcDatabaseGetOption") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE } - if db.private_data != nil { - setErr(err, "AdbcDatabaseNew: database already allocated") + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export PanicDummyDatabaseGetOptionBytes +func PanicDummyDatabaseGetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionBytes") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} - hndl := cgo.NewHandle(dbobj) - db.private_data = createHandle(hndl) - return C.ADBC_STATUS_OK + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) } -//export PanicDummyDatabaseSetOption -func PanicDummyDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export PanicDummyDatabaseGetOptionDouble +func PanicDummyDatabaseGetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseSetOption", e) + code = poison(err, "AdbcDatabaseGetOptionDouble", e) } }() - if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionDouble") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - cdb := getFromHandle[cDatabase](db.private_data) - k, v := C.GoString(key), C.GoString(value) - cdb.opts[k] = v + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return C.ADBC_STATUS_OK + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export PanicDummyDatabaseGetOptionInt +func PanicDummyDatabaseGetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export PanicDummyDatabaseInit @@ -222,6 +517,27 @@ func PanicDummyDatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) return C.ADBC_STATUS_OK } +//export PanicDummyDatabaseNew +func PanicDummyDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseNew", e) + } + }() + if atomic.LoadInt32(&globalPoison) != 0 { + setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") + return C.ADBC_STATUS_INTERNAL + } + if db.private_data != nil { + setErr(err, "AdbcDatabaseNew: database already allocated") + return C.ADBC_STATUS_INVALID_STATE + } + dbobj := &cDatabase{opts: make(map[string]string)} + hndl := cgo.NewHandle(dbobj) + db.private_data = createHandle(hndl) + return C.ADBC_STATUS_OK +} + //export PanicDummyDatabaseRelease func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -250,7 +566,99 @@ func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErro return C.ADBC_STATUS_OK } +//export PanicDummyDatabaseSetOption +func PanicDummyDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOption", e) + } + }() + if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + return C.ADBC_STATUS_INVALID_STATE + } + cdb := getFromHandle[cDatabase](db.private_data) + + k, v := C.GoString(key), C.GoString(value) + if cdb.db != nil { + opts, ok := cdb.db.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(k, v))) + } else { + cdb.opts[k] = v + } + + return C.ADBC_STATUS_OK +} + +//export PanicDummyDatabaseSetOptionBytes +func PanicDummyDatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export PanicDummyDatabaseSetOptionDouble +func PanicDummyDatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionDouble", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionDouble") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export PanicDummyDatabaseSetOptionInt +func PanicDummyDatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) +} + type cConn struct { + cancellableContext + cnxn adbc.Connection initArgs map[string]string } @@ -284,6 +692,102 @@ func checkConnInit(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname return conn } +//export PanicDummyConnectionGetOption +func PanicDummyConnectionGetOption(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOption", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOption") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export PanicDummyConnectionGetOptionBytes +func PanicDummyConnectionGetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export PanicDummyConnectionGetOptionDouble +func PanicDummyConnectionGetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export PanicDummyConnectionGetOptionInt +func PanicDummyConnectionGetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + //export PanicDummyConnectionNew func PanicDummyConnectionNew(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -327,13 +831,75 @@ func PanicDummyConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cc return C.ADBC_STATUS_OK } - opts, ok := conn.cnxn.(adbc.PostInitOptions) + opts, ok := conn.cnxn.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcConnectionSetOption: not supported post-init") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))) +} + +//export PanicDummyConnectionSetOptionBytes +func PanicDummyConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export PanicDummyConnectionSetOptionDouble +func PanicDummyConnectionSetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export PanicDummyConnectionSetOptionInt +func PanicDummyConnectionSetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) if !ok { - setErr(err, "AdbcConnectionSetOption: not supported post-init") + setErr(err, "AdbcConnectionSetOptionInt: options are not supported") return C.ADBC_STATUS_NOT_IMPLEMENTED } - rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val))) - return C.AdbcStatusCode(rawCode) + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export PanicDummyConnectionInit @@ -396,8 +962,9 @@ func PanicDummyConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_Ad conn := h.Value().(*cConn) defer func() { + conn.cancelContext() conn.cnxn = nil - C.free(unsafe.Pointer(cnxn.private_data)) + C.free(cnxn.private_data) cnxn.private_data = nil h.Delete() // manually trigger GC for two reasons: @@ -434,26 +1001,19 @@ func toCdataArray(ptr *C.struct_ArrowArray) *cdata.CArrowArray { return (*cdata.CArrowArray)(unsafe.Pointer(ptr)) } -//export PanicDummyConnectionGetInfo -func PanicDummyConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.uint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export PanicDummyConnectionCancel +func PanicDummyConnectionCancel(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcConnectionGetInfo", e) + code = poison(err, "AdbcConnectionCancel", e) } }() - conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + conn := checkConnInit(cnxn, err, "AdbcConnectionCancel") if conn == nil { return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) - rdr, e := conn.cnxn.GetInfo(context.Background(), infoCodes) - if e != nil { - return C.AdbcStatusCode(errToAdbcErr(err, e)) - } - - defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + conn.cancelContext() return C.ADBC_STATUS_OK } @@ -481,6 +1041,29 @@ func toStrSlice(in **C.cchar_t) []string { return out } +//export PanicDummyConnectionGetInfo +func PanicDummyConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetInfo", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + //export PanicDummyConnectionGetObjects func PanicDummyConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, catalog, dbSchema, tableName *C.cchar_t, tableType **C.cchar_t, columnName *C.cchar_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { @@ -494,12 +1077,67 @@ func PanicDummyConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetObjects(context.Background(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + rdr, e := conn.cnxn.GetObjects(conn.newContext(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export PanicDummyConnectionGetStatistics +func PanicDummyConnectionGetStatistics(cnxn *C.struct_AdbcConnection, catalog, dbSchema, tableName *C.cchar_t, approximate C.char, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatistics(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), int(approximate) != 0) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export PanicDummyConnectionGetStatisticNames +func PanicDummyConnectionGetStatisticNames(cnxn *C.struct_AdbcConnection, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatisticNames(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -515,7 +1153,7 @@ func PanicDummyConnectionGetTableSchema(cnxn *C.struct_AdbcConnection, catalog, return C.ADBC_STATUS_INVALID_STATE } - sc, e := conn.cnxn.GetTableSchema(context.Background(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) + sc, e := conn.cnxn.GetTableSchema(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -535,12 +1173,12 @@ func PanicDummyConnectionGetTableTypes(cnxn *C.struct_AdbcConnection, out *C.str return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetTableTypes(context.Background()) + rdr, e := conn.cnxn.GetTableTypes(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -556,12 +1194,12 @@ func PanicDummyConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialized return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(context.Background(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -577,7 +1215,7 @@ func PanicDummyConnectionCommit(cnxn *C.struct_AdbcConnection, err *C.struct_Adb return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(conn.newContext()))) } //export PanicDummyConnectionRollback @@ -592,25 +1230,137 @@ func PanicDummyConnectionRollback(cnxn *C.struct_AdbcConnection, err *C.struct_A return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(conn.newContext()))) +} + +type cStmt struct { + cancellableContext + + stmt adbc.Statement } -func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) adbc.Statement { +func checkStmtAlloc(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) - return nil + return false } if stmt == nil { setErr(err, "%s: statement not allocated", fname) - return nil + return false } - if stmt.private_data == nil { - setErr(err, "%s: statement not initialized", fname) + setErr(err, "%s: statement not allocated", fname) + return false + } + return true +} + +func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) *cStmt { + if !checkStmtAlloc(stmt, err, fname) { + return nil + } + cStmt := getFromHandle[cStmt](stmt.private_data) + if cStmt.stmt == nil { + setErr(err, "%s: statement not allocated", fname) return nil } + return cStmt +} + +//export PanicDummyStatementGetOption +func PanicDummyStatementGetOption(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOption", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOption") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export PanicDummyStatementGetOptionBytes +func PanicDummyStatementGetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export PanicDummyStatementGetOptionDouble +func PanicDummyStatementGetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export PanicDummyStatementGetOptionInt +func PanicDummyStatementGetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return (*(*cgo.Handle)(stmt.private_data)).Value().(adbc.Statement) + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export PanicDummyStatementNew @@ -639,8 +1389,8 @@ func PanicDummyStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSt return C.AdbcStatusCode(errToAdbcErr(err, e)) } - h := cgo.NewHandle(st) - stmt.private_data = createHandle(h) + hndl := cgo.NewHandle(&cStmt{stmt: st}) + stmt.private_data = createHandle(hndl) return C.ADBC_STATUS_OK } @@ -655,31 +1405,46 @@ func PanicDummyStatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_Adbc setErr(err, "AdbcStatementRelease: Go panicked, driver is in unknown state") return C.ADBC_STATUS_INTERNAL } - if stmt == nil { - setErr(err, "AdbcStatementRelease: statement not allocated") + if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } + h := (*(*cgo.Handle)(stmt.private_data)) - if stmt.private_data == nil { - setErr(err, "AdbcStatementRelease: statement not initialized") - return C.ADBC_STATUS_INVALID_STATE + st := h.Value().(*cStmt) + defer func() { + st.cancelContext() + st.stmt = nil + C.free(stmt.private_data) + stmt.private_data = nil + h.Delete() + // manually trigger GC for two reasons: + // 1. ASAN expects the release callback to be called before + // the process ends, but GC is not deterministic. So by manually + // triggering the GC we ensure the release callback gets called. + // 2. Creates deterministic GC behavior by all Release functions + // triggering a garbage collection + runtime.GC() + }() + if st.stmt == nil { + return C.ADBC_STATUS_OK } + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Close())) +} - h := (*(*cgo.Handle)(stmt.private_data)) - st := h.Value().(adbc.Statement) - C.free(stmt.private_data) - stmt.private_data = nil +//export PanicDummyStatementCancel +func PanicDummyStatementCancel(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementCancel", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementCancel") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } - e := st.Close() - h.Delete() - // manually trigger GC for two reasons: - // 1. ASAN expects the release callback to be called before - // the process ends, but GC is not deterministic. So by manually - // triggering the GC we ensure the release callback gets called. - // 2. Creates deterministic GC behavior by all Release functions - // triggering a garbage collection - runtime.GC() - return C.AdbcStatusCode(errToAdbcErr(err, e)) + st.cancelContext() + return C.ADBC_STATUS_OK } //export PanicDummyStatementPrepare @@ -694,7 +1459,7 @@ func PanicDummyStatementPrepare(stmt *C.struct_AdbcStatement, err *C.struct_Adbc return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.Prepare(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Prepare(st.newContext()))) } //export PanicDummyStatementExecuteQuery @@ -710,7 +1475,7 @@ func PanicDummyStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct } if out == nil { - n, e := st.ExecuteUpdate(context.Background()) + n, e := st.stmt.ExecuteUpdate(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -719,7 +1484,7 @@ func PanicDummyStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct *affected = C.int64_t(n) } } else { - rdr, n, e := st.ExecuteQuery(context.Background()) + rdr, n, e := st.stmt.ExecuteQuery(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -729,8 +1494,35 @@ func PanicDummyStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) + } + return C.ADBC_STATUS_OK +} + +//export PanicDummyStatementExecuteSchema +func PanicDummyStatementExecuteSchema(stmt *C.struct_AdbcStatement, schema *C.struct_ArrowSchema, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementExecuteQuery", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementExecuteQuery") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + es, ok := st.stmt.(adbc.StatementExecuteSchema) + if !ok { + setErr(err, "AdbcStatementExecuteSchema: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + sc, e := es.ExecuteSchema(st.newContext()) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) } + + cdata.ExportArrowSchema(sc, toCdataSchema(schema)) return C.ADBC_STATUS_OK } @@ -746,7 +1538,7 @@ func PanicDummyStatementSetSqlQuery(stmt *C.struct_AdbcStatement, query *C.cchar return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSqlQuery(C.GoString(query)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSqlQuery(C.GoString(query)))) } //export PanicDummyStatementSetSubstraitPlan @@ -761,7 +1553,7 @@ func PanicDummyStatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C.c return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) } //export PanicDummyStatementBind @@ -784,7 +1576,7 @@ func PanicDummyStatementBind(stmt *C.struct_AdbcStatement, values *C.struct_Arro } defer rec.Release() - return C.AdbcStatusCode(errToAdbcErr(err, st.Bind(context.Background(), rec))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Bind(st.newContext(), rec))) } //export PanicDummyStatementBindStream @@ -803,7 +1595,7 @@ func PanicDummyStatementBindStream(stmt *C.struct_AdbcStatement, stream *C.struc if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } - return C.AdbcStatusCode(errToAdbcErr(err, st.BindStream(context.Background(), rdr.(array.RecordReader)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.BindStream(st.newContext(), rdr.(array.RecordReader)))) } //export PanicDummyStatementGetParameterSchema @@ -818,7 +1610,7 @@ func PanicDummyStatementGetParameterSchema(stmt *C.struct_AdbcStatement, schema return C.ADBC_STATUS_INVALID_STATE } - sc, e := st.GetParameterSchema() + sc, e := st.stmt.GetParameterSchema() if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -839,7 +1631,70 @@ func PanicDummyStatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.cc return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetOption(C.GoString(key), C.GoString(value)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetOption(C.GoString(key), C.GoString(value)))) +} + +//export PanicDummyStatementSetOptionBytes +func PanicDummyStatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export PanicDummyStatementSetOptionDouble +func PanicDummyStatementSetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export PanicDummyStatementSetOptionInt +func PanicDummyStatementSetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export releasePartitions @@ -868,7 +1723,7 @@ func PanicDummyStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema * return C.ADBC_STATUS_INVALID_STATE } - sc, part, n, e := st.ExecutePartitions(context.Background()) + sc, part, n, e := st.stmt.ExecutePartitions(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -913,13 +1768,20 @@ func PanicDummyStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema * //export PanicDummyDriverInit func PanicDummyDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcError) C.AdbcStatusCode { - if version != C.ADBC_VERSION_1_0_0 { - setErr(err, "Only version %d supported, got %d", int(C.ADBC_VERSION_1_0_0), int(version)) + driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) + + switch version { + case C.ADBC_VERSION_1_0_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) + memory.Set(sink, 0) + case C.ADBC_VERSION_1_1_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) + memory.Set(sink, 0) + default: + setErr(err, "Only version 1.0.0/1.1.0 supported, got %d", int(version)) return C.ADBC_STATUS_NOT_IMPLEMENTED } - driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) - C.memset(unsafe.Pointer(driver), 0, C.sizeof_struct_AdbcDriver) driver.DatabaseInit = (*[0]byte)(C.PanicDummyDatabaseInit) driver.DatabaseNew = (*[0]byte)(C.PanicDummyDatabaseNew) driver.DatabaseRelease = (*[0]byte)(C.PanicDummyDatabaseRelease) @@ -949,6 +1811,41 @@ func PanicDummyDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcEr driver.StatementGetParameterSchema = (*[0]byte)(C.PanicDummyStatementGetParameterSchema) driver.StatementPrepare = (*[0]byte)(C.PanicDummyStatementPrepare) + if version == C.ADBC_VERSION_1_1_0 { + driver.ErrorGetDetailCount = (*[0]byte)(C.PanicDummyErrorGetDetailCount) + driver.ErrorGetDetail = (*[0]byte)(C.PanicDummyErrorGetDetail) + driver.ErrorFromArrayStream = (*[0]byte)(C.PanicDummyErrorFromArrayStream) + + driver.DatabaseGetOption = (*[0]byte)(C.PanicDummyDatabaseGetOption) + driver.DatabaseGetOptionBytes = (*[0]byte)(C.PanicDummyDatabaseGetOptionBytes) + driver.DatabaseGetOptionDouble = (*[0]byte)(C.PanicDummyDatabaseGetOptionDouble) + driver.DatabaseGetOptionInt = (*[0]byte)(C.PanicDummyDatabaseGetOptionInt) + driver.DatabaseSetOptionBytes = (*[0]byte)(C.PanicDummyDatabaseSetOptionBytes) + driver.DatabaseSetOptionDouble = (*[0]byte)(C.PanicDummyDatabaseSetOptionDouble) + driver.DatabaseSetOptionInt = (*[0]byte)(C.PanicDummyDatabaseSetOptionInt) + + driver.ConnectionCancel = (*[0]byte)(C.PanicDummyConnectionCancel) + driver.ConnectionGetOption = (*[0]byte)(C.PanicDummyConnectionGetOption) + driver.ConnectionGetOptionBytes = (*[0]byte)(C.PanicDummyConnectionGetOptionBytes) + driver.ConnectionGetOptionDouble = (*[0]byte)(C.PanicDummyConnectionGetOptionDouble) + driver.ConnectionGetOptionInt = (*[0]byte)(C.PanicDummyConnectionGetOptionInt) + driver.ConnectionGetStatistics = (*[0]byte)(C.PanicDummyConnectionGetStatistics) + driver.ConnectionGetStatisticNames = (*[0]byte)(C.PanicDummyConnectionGetStatisticNames) + driver.ConnectionSetOptionBytes = (*[0]byte)(C.PanicDummyConnectionSetOptionBytes) + driver.ConnectionSetOptionDouble = (*[0]byte)(C.PanicDummyConnectionSetOptionDouble) + driver.ConnectionSetOptionInt = (*[0]byte)(C.PanicDummyConnectionSetOptionInt) + + driver.StatementCancel = (*[0]byte)(C.PanicDummyStatementCancel) + driver.StatementExecuteSchema = (*[0]byte)(C.PanicDummyStatementExecuteSchema) + driver.StatementGetOption = (*[0]byte)(C.PanicDummyStatementGetOption) + driver.StatementGetOptionBytes = (*[0]byte)(C.PanicDummyStatementGetOptionBytes) + driver.StatementGetOptionDouble = (*[0]byte)(C.PanicDummyStatementGetOptionDouble) + driver.StatementGetOptionInt = (*[0]byte)(C.PanicDummyStatementGetOptionInt) + driver.StatementSetOptionBytes = (*[0]byte)(C.PanicDummyStatementSetOptionBytes) + driver.StatementSetOptionDouble = (*[0]byte)(C.PanicDummyStatementSetOptionDouble) + driver.StatementSetOptionInt = (*[0]byte)(C.PanicDummyStatementSetOptionInt) + } + return C.ADBC_STATUS_OK } diff --git a/go/adbc/pkg/panicdummy/utils.c b/go/adbc/pkg/panicdummy/utils.c index d0a2936618..526cd27ca4 100644 --- a/go/adbc/pkg/panicdummy/utils.c +++ b/go/adbc/pkg/panicdummy/utils.c @@ -35,52 +35,142 @@ void PanicDummy_release_error(struct AdbcError* error) { error->release = NULL; } -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return PanicDummyDatabaseNew(database, error); +void PanicDummyReleaseErrWithDetails(struct AdbcError* error) { + if (!error || error->release != PanicDummyReleaseErrWithDetails || + !error->private_data) { + return; + } + + struct PanicDummyError* details = + (struct PanicDummyError*) error->private_data; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + free(details->keys); + free(details->values); + free(details->lengths); + free(details); + + free(error->message); + error->message = NULL; + error->release = NULL; + error->private_data = NULL; } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return PanicDummyDatabaseSetOption(database, key, value, error); +int PanicDummyErrorGetDetailCount(const struct AdbcError* error) { + if (!error || error->release != PanicDummyReleaseErrWithDetails || + !error->private_data) { + return 0; + } + + return ((struct PanicDummyError*) error->private_data)->count; +} + +struct AdbcErrorDetail PanicDummyErrorGetDetail(const struct AdbcError* error, + int index) { + if (!error || error->release != PanicDummyReleaseErrWithDetails || + !error->private_data) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct PanicDummyError* details = (struct PanicDummyError*) error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index] + }; +} + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return PanicDummyErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return PanicDummyErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return PanicDummyErrorFromArrayStream(stream, status); +} + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PanicDummyDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return PanicDummyDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return PanicDummyDatabaseGetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return PanicDummyDatabaseGetOptionInt(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return PanicDummyDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return PanicDummyDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return PanicDummyDatabaseRelease(database, error); } -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, - struct AdbcError* error) { - return PanicDummyConnectionNew(connection, error); +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return PanicDummyDatabaseSetOption(database, key, value, error); } -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, - const char* value, struct AdbcError* error) { - return PanicDummyConnectionSetOption(connection, key, value, error); +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return PanicDummyDatabaseSetOptionBytes(database, key, value, length, error); } -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, - struct AdbcDatabase* database, - struct AdbcError* error) { - return PanicDummyConnectionInit(connection, database, error); +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return PanicDummyDatabaseSetOptionDouble(database, key, value, error); } -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, - struct AdbcError* error) { - return PanicDummyConnectionRelease(connection, error); +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return PanicDummyDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return PanicDummyConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return PanicDummyConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); - return PanicDummyConnectionGetInfo(connection, info_codes, info_codes_length, out, - error); + return PanicDummyConnectionGetInfo(connection, info_codes, info_codes_length, + out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -91,7 +181,46 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return PanicDummyConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_type, column_name, out, error); + table_type, column_name, out, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PanicDummyConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PanicDummyConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return PanicDummyConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return PanicDummyConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return PanicDummyConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return PanicDummyConnectionGetStatisticNames(connection, out, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -101,7 +230,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, struct AdbcError* error) { if (schema) memset(schema, 0, sizeof(*schema)); return PanicDummyConnectionGetTableSchema(connection, catalog, db_schema, table_name, - schema, error); + schema, error); } AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, @@ -111,6 +240,17 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, return PanicDummyConnectionGetTableTypes(connection, out, error); } +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return PanicDummyConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return PanicDummyConnectionNew(connection, error); +} + AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, const uint8_t* serialized_partition, size_t serialized_length, @@ -118,12 +258,12 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return PanicDummyConnectionReadPartition(connection, serialized_partition, - serialized_length, out, error); + serialized_length, out, error); } -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, - struct AdbcError* error) { - return PanicDummyConnectionCommit(connection, error); +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return PanicDummyConnectionRelease(connection, error); } AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, @@ -131,39 +271,32 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return PanicDummyConnectionRollback(connection, error); } -AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, - struct AdbcStatement* statement, - struct AdbcError* error) { - return PanicDummyStatementNew(connection, statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, - struct AdbcError* error) { - return PanicDummyStatementRelease(statement, error); +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return PanicDummyConnectionSetOption(connection, key, value, error); } -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, - struct ArrowArrayStream* out, - int64_t* rows_affected, - struct AdbcError* error) { - if (out) memset(out, 0, sizeof(*out)); - return PanicDummyStatementExecuteQuery(statement, out, rows_affected, error); +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PanicDummyConnectionSetOptionBytes(connection, key, value, length, error); } -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, - struct AdbcError* error) { - return PanicDummyStatementPrepare(statement, error); +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return PanicDummyConnectionSetOptionDouble(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, - const char* query, struct AdbcError* error) { - return PanicDummyStatementSetSqlQuery(statement, query, error); +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return PanicDummyConnectionSetOptionInt(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, - const uint8_t* plan, size_t length, - struct AdbcError* error) { - return PanicDummyStatementSetSubstraitPlan(statement, plan, length, error); +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return PanicDummyStatementCancel(statement, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -178,6 +311,56 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return PanicDummyStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + if (partitions) memset(partitions, 0, sizeof(*partitions)); + return PanicDummyStatementExecutePartitions(statement, schema, partitions, + rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (out) memset(out, 0, sizeof(*out)); + return PanicDummyStatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + return PanicDummyStatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PanicDummyStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PanicDummyStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return PanicDummyStatementGetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + return PanicDummyStatementGetOptionInt(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -185,20 +368,54 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, return PanicDummyStatementGetParameterSchema(statement, schema, error); } +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return PanicDummyStatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return PanicDummyStatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return PanicDummyStatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return PanicDummyStatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return PanicDummyStatementSetSubstraitPlan(statement, plan, length, error); +} + AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { return PanicDummyStatementSetOption(statement, key, value, error); } -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, - struct ArrowSchema* schema, - struct AdbcPartitions* partitions, - int64_t* rows_affected, - struct AdbcError* error) { - if (schema) memset(schema, 0, sizeof(*schema)); - if (partitions) memset(partitions, 0, sizeof(*partitions)); - return PanicDummyStatementExecutePartitions(statement, schema, partitions, - rows_affected, error); +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PanicDummyStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return PanicDummyStatementSetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + return PanicDummyStatementSetOptionInt(statement, key, value, error); } ADBC_EXPORT @@ -206,6 +423,23 @@ AdbcStatusCode AdbcDriverInit(int version, void* driver, struct AdbcError* error return PanicDummyDriverInit(version, driver, error); } +int PanicDummyArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +int PanicDummyArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); + +int PanicDummyArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return PanicDummyArrayStreamGetSchema(stream, out); +} + +int PanicDummyArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return PanicDummyArrayStreamGetNext(stream, out); +} + #ifdef __cplusplus } #endif diff --git a/go/adbc/pkg/panicdummy/utils.h b/go/adbc/pkg/panicdummy/utils.h index f3b3aae13e..91d8294c4e 100644 --- a/go/adbc/pkg/panicdummy/utils.h +++ b/go/adbc/pkg/panicdummy/utils.h @@ -26,71 +26,133 @@ #include #include "../../drivermgr/adbc.h" +struct AdbcError* PanicDummyErrorFromArrayStream(struct ArrowArrayStream*, + AdbcStatusCode*); +AdbcStatusCode PanicDummyDatabaseGetOption(struct AdbcDatabase*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseGetOptionBytes(struct AdbcDatabase*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseGetOptionDouble(struct AdbcDatabase*, const char*, + double*, struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseGetOptionInt(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode PanicDummyDatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode PanicDummyDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode PanicDummyDatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); -AdbcStatusCode PanicDummyDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode PanicDummyDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionNew(struct AdbcConnection* cnxn, - struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionSetOption(struct AdbcConnection* cnxn, const char* key, - const char* val, struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionInit(struct AdbcConnection* cnxn, - struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionRelease(struct AdbcConnection* cnxn, - struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionGetInfo(struct AdbcConnection* cnxn, uint32_t* codes, - size_t len, struct ArrowArrayStream* out, +AdbcStatusCode PanicDummyDatabaseSetOptionBytes(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseSetOptionDouble(struct AdbcDatabase*, const char*, + double, struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseSetOptionInt(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + +AdbcStatusCode PanicDummyConnectionCancel(struct AdbcConnection*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionCommit(struct AdbcConnection* cnxn, + struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionGetInfo(struct AdbcConnection* cnxn, + const uint32_t* codes, size_t len, + struct ArrowArrayStream* out, struct AdbcError* err); AdbcStatusCode PanicDummyConnectionGetObjects( struct AdbcConnection* cnxn, int depth, const char* catalog, const char* dbSchema, const char* tableName, const char** tableType, const char* columnName, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionGetOption(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetOptionBytes(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetOptionDouble(struct AdbcConnection*, const char*, + double*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetOptionInt(struct AdbcConnection*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetStatistics(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, + struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); AdbcStatusCode PanicDummyConnectionGetTableSchema( struct AdbcConnection* cnxn, const char* catalog, const char* dbSchema, const char* tableName, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode PanicDummyConnectionGetTableTypes(struct AdbcConnection* cnxn, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionInit(struct AdbcConnection* cnxn, + struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionNew(struct AdbcConnection* cnxn, + struct AdbcError* err); AdbcStatusCode PanicDummyConnectionReadPartition(struct AdbcConnection* cnxn, const uint8_t* serialized, size_t serializedLen, struct ArrowArrayStream* out, struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionCommit(struct AdbcConnection* cnxn, - struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionRelease(struct AdbcConnection* cnxn, + struct AdbcError* err); AdbcStatusCode PanicDummyConnectionRollback(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementNew(struct AdbcConnection* cnxn, - struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementRelease(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode PanicDummyStatementPrepare(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode PanicDummyStatementExecuteQuery(struct AdbcStatement* stmt, - struct ArrowArrayStream* out, - int64_t* affected, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementSetSqlQuery(struct AdbcStatement* stmt, - const char* query, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementSetSubstraitPlan(struct AdbcStatement* stmt, - const uint8_t* plan, size_t length, - struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionSetOption(struct AdbcConnection* cnxn, const char* key, + const char* val, struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode PanicDummyConnectionSetOptionDouble(struct AdbcConnection*, const char*, + double, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionSetOptionInt(struct AdbcConnection*, const char*, + int64_t, struct AdbcError*); + AdbcStatusCode PanicDummyStatementBind(struct AdbcStatement* stmt, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode PanicDummyStatementBindStream(struct AdbcStatement* stmt, struct ArrowArrayStream* stream, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementGetParameterSchema(struct AdbcStatement* stmt, - struct ArrowSchema* schema, - struct AdbcError* err); -AdbcStatusCode PanicDummyStatementSetOption(struct AdbcStatement* stmt, const char* key, - const char* value, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementCancel(struct AdbcStatement*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementExecuteQuery(struct AdbcStatement* stmt, + struct ArrowArrayStream* out, + int64_t* affected, struct AdbcError* err); AdbcStatusCode PanicDummyStatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementExecuteSchema(struct AdbcStatement*, + struct ArrowSchema*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetOption(struct AdbcStatement*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetOptionBytes(struct AdbcStatement*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetOptionDouble(struct AdbcStatement*, const char*, + double*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetOptionInt(struct AdbcStatement*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetParameterSchema(struct AdbcStatement* stmt, + struct ArrowSchema* schema, + struct AdbcError* err); +AdbcStatusCode PanicDummyStatementNew(struct AdbcConnection* cnxn, + struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementPrepare(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode PanicDummyStatementRelease(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode PanicDummyStatementSetOption(struct AdbcStatement* stmt, const char* key, + const char* value, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementSetOptionBytes(struct AdbcStatement*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode PanicDummyStatementSetOptionDouble(struct AdbcStatement*, const char*, + double, struct AdbcError*); +AdbcStatusCode PanicDummyStatementSetOptionInt(struct AdbcStatement*, const char*, + int64_t, struct AdbcError*); +AdbcStatusCode PanicDummyStatementSetSqlQuery(struct AdbcStatement* stmt, + const char* query, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementSetSubstraitPlan(struct AdbcStatement* stmt, + const uint8_t* plan, size_t length, + struct AdbcError* err); + AdbcStatusCode PanicDummyDriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void PanicDummyerrRelease(struct AdbcError* error) { @@ -98,3 +160,21 @@ static inline void PanicDummyerrRelease(struct AdbcError* error) { } void PanicDummy_release_error(struct AdbcError* error); + +struct PanicDummyError { + char* message; + char** keys; + uint8_t** values; + size_t* lengths; + int count; +}; + +void PanicDummyReleaseErrWithDetails(struct AdbcError* error); + +int PanicDummyErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail PanicDummyErrorGetDetail(const struct AdbcError* error, int index); + +int PanicDummyArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out); +int PanicDummyArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out); diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go index 51b6ce5bc9..6ca09646d4 100644 --- a/go/adbc/pkg/snowflake/driver.go +++ b/go/adbc/pkg/snowflake/driver.go @@ -28,11 +28,22 @@ package main // #cgo CXXFLAGS: -std=c++11 -DADBC_EXPORTING // #include "../../drivermgr/adbc.h" // #include "utils.h" +// #include // #include // #include // // typedef const char cchar_t; // typedef const uint8_t cuint8_t; +// typedef const uint32_t cuint32_t; +// typedef const struct AdbcError ConstAdbcError; +// +// int SnowflakeArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// int SnowflakeArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// const char* SnowflakeArrayStreamGetLastError(struct ArrowArrayStream*); +// void SnowflakeArrayStreamRelease(struct ArrowArrayStream*); +// +// int SnowflakeArrayStreamGetSchemaTrampoline(struct ArrowArrayStream*, struct ArrowSchema*); +// int SnowflakeArrayStreamGetNextTrampoline(struct ArrowArrayStream*, struct ArrowArray*); // // void releasePartitions(struct AdbcPartitions* partitions); // @@ -51,6 +62,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc/driver/snowflake" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/cdata" + "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/arrow/memory/mallocator" ) @@ -78,14 +90,63 @@ func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) { err.release = (*[0]byte)(C.Snowflake_release_error) } +func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { + if err == nil { + return + } + + if err.vendor_code != C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA { + setErr(err, adbcError.Msg) + return + } + + cErrPtr := C.malloc(C.sizeof_struct_SnowflakeError) + cErr := (*C.struct_SnowflakeError)(cErrPtr) + cErr.message = C.CString(adbcError.Msg) + err.message = cErr.message + err.release = (*[0]byte)(C.SnowflakeReleaseErrWithDetails) + err.private_data = cErrPtr + + numDetails := len(adbcError.Details) + if numDetails > 0 { + cErr.keys = (**C.cchar_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cchar_t)(nil))))) + cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) + cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + + keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) + values := fromCArr[*C.cuint8_t](cErr.values, numDetails) + lengths := fromCArr[C.size_t](cErr.lengths, numDetails) + + for i, detail := range adbcError.Details { + keys[i] = C.CString(detail.Key()) + bytes, err := detail.Serialize() + if err != nil { + msg := err.Error() + values[i] = (*C.cuint8_t)(unsafe.Pointer(C.CString(msg))) + lengths[i] = C.size_t(len(msg)) + } else { + values[i] = (*C.cuint8_t)(C.malloc(C.size_t(len(bytes)))) + sink := fromCArr[byte]((*byte)(values[i]), len(bytes)) + copy(sink, bytes) + lengths[i] = C.size_t(len(bytes)) + } + } + } else { + cErr.keys = nil + cErr.values = nil + cErr.lengths = nil + } + cErr.count = C.int(numDetails) +} + func errToAdbcErr(adbcerr *C.struct_AdbcError, err error) adbc.Status { - if adbcerr == nil || err == nil { + if err == nil { return adbc.StatusOK } var adbcError adbc.Error if errors.As(err, &adbcError) { - setErr(adbcerr, adbcError.Msg) + setErrWithDetails(adbcerr, adbcError) return adbcError.Code } @@ -123,6 +184,45 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T { return cgo.Handle((uintptr)(*hptr)).Value().(*T) } +func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { + lenWithTerminator := C.size_t(len(val) + 1) + if lenWithTerminator <= *length { + sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) + copy(sink, val) + sink[lenWithTerminator] = 0 + } + *length = lenWithTerminator + return C.ADBC_STATUS_OK +} + +func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode { + if C.size_t(len(val)) <= *length { + sink := fromCArr[byte]((*byte)(out), int(*length)) + copy(sink, val) + } + *length = C.size_t(len(val)) + return C.ADBC_STATUS_OK +} + +type cancellableContext struct { + ctx context.Context + cancel context.CancelFunc +} + +func (c *cancellableContext) newContext() context.Context { + c.cancelContext() + c.ctx, c.cancel = context.WithCancel(context.Background()) + return c.ctx +} + +func (c *cancellableContext) cancelContext() { + if c.cancel != nil { + c.cancel() + } + c.ctx = nil + c.cancel = nil +} + func checkDBAlloc(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) @@ -152,48 +252,243 @@ func checkDBInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname strin return cdb } +// Custom ArrowArrayStream export to support ADBC error data in ArrowArrayStream + +type cArrayStream struct { + rdr array.RecordReader + // Must be C-allocated + adbcErr *C.struct_AdbcError + status C.AdbcStatusCode +} + +func (cStream *cArrayStream) maybeError() C.int { + err := cStream.rdr.Err() + if err != nil { + if cStream.adbcErr != nil { + C.SnowflakeerrRelease(cStream.adbcErr) + } else { + cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + } + cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) + switch adbc.Status(cStream.status) { + case adbc.StatusUnknown: + return C.EIO + case adbc.StatusNotImplemented: + return C.ENOTSUP + case adbc.StatusNotFound: + return C.ENOENT + case adbc.StatusAlreadyExists: + return C.EEXIST + case adbc.StatusInvalidArgument: + return C.EINVAL + case adbc.StatusInvalidState: + return C.EINVAL + case adbc.StatusInvalidData: + return C.EIO + case adbc.StatusIntegrity: + return C.EIO + case adbc.StatusInternal: + return C.EIO + case adbc.StatusIO: + return C.EIO + case adbc.StatusCancelled: + return C.ECANCELED + case adbc.StatusTimeout: + return C.ETIMEDOUT + case adbc.StatusUnauthenticated: + return C.EACCES + case adbc.StatusUnauthorized: + return C.EACCES + default: + return C.EIO + } + } + return 0 +} + +//export SnowflakeArrayStreamGetLastError +func SnowflakeArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.adbcErr != nil { + return cStream.adbcErr.message + } + return nil +} + +//export SnowflakeArrayStreamGetNext +func SnowflakeArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.rdr.Next() { + cdata.ExportArrowRecordBatch(cStream.rdr.Record(), toCdataArray(array), nil) + return 0 + } + array.release = nil + array.private_data = nil + return cStream.maybeError() +} + +//export SnowflakeArrayStreamGetSchema +func SnowflakeArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + s := cStream.rdr.Schema() + if s == nil { + return cStream.maybeError() + } + cdata.ExportArrowSchema(s, toCdataSchema(schema)) + return 0 +} + +//export SnowflakeArrayStreamRelease +func SnowflakeArrayStreamRelease(stream *C.struct_ArrowArrayStream) { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return + } + h := (*(*cgo.Handle)(stream.private_data)) + + cStream := h.Value().(*cArrayStream) + cStream.rdr.Release() + if cStream.adbcErr != nil { + C.SnowflakeerrRelease(cStream.adbcErr) + C.free(unsafe.Pointer(cStream.adbcErr)) + } + C.free(unsafe.Pointer(stream.private_data)) + stream.private_data = nil + h.Delete() + runtime.GC() +} + +//export SnowflakeErrorFromArrayStream +func SnowflakeErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) *C.struct_AdbcError { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if status != nil { + *status = cStream.status + } + return cStream.adbcErr +} + +func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStream) { + cStream := &cArrayStream{rdr: rdr, status: C.ADBC_STATUS_OK} + stream.get_last_error = (*[0]byte)(C.SnowflakeArrayStreamGetLastError) + stream.get_next = (*[0]byte)(C.SnowflakeArrayStreamGetNextTrampoline) + stream.get_schema = (*[0]byte)(C.SnowflakeArrayStreamGetSchemaTrampoline) + stream.release = (*[0]byte)(C.SnowflakeArrayStreamRelease) + hndl := cgo.NewHandle(cStream) + stream.private_data = createHandle(hndl) + rdr.Retain() +} + type cDatabase struct { opts map[string]string db adbc.Database } -//export SnowflakeDatabaseNew -func SnowflakeDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export SnowflakeDatabaseGetOption +func SnowflakeDatabaseGetOption(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseNew", e) + code = poison(err, "AdbcDatabaseGetOption", e) } }() - if atomic.LoadInt32(&globalPoison) != 0 { - setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") - return C.ADBC_STATUS_INTERNAL + cdb := checkDBInit(db, err, "AdbcDatabaseGetOption") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE } - if db.private_data != nil { - setErr(err, "AdbcDatabaseNew: database already allocated") + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export SnowflakeDatabaseGetOptionBytes +func SnowflakeDatabaseGetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionBytes") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} - hndl := cgo.NewHandle(dbobj) - db.private_data = createHandle(hndl) - return C.ADBC_STATUS_OK + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) } -//export SnowflakeDatabaseSetOption -func SnowflakeDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export SnowflakeDatabaseGetOptionDouble +func SnowflakeDatabaseGetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseSetOption", e) + code = poison(err, "AdbcDatabaseGetOptionDouble", e) } }() - if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionDouble") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - cdb := getFromHandle[cDatabase](db.private_data) - k, v := C.GoString(key), C.GoString(value) - cdb.opts[k] = v + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return C.ADBC_STATUS_OK + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export SnowflakeDatabaseGetOptionInt +func SnowflakeDatabaseGetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export SnowflakeDatabaseInit @@ -222,6 +517,27 @@ func SnowflakeDatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) ( return C.ADBC_STATUS_OK } +//export SnowflakeDatabaseNew +func SnowflakeDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseNew", e) + } + }() + if atomic.LoadInt32(&globalPoison) != 0 { + setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") + return C.ADBC_STATUS_INTERNAL + } + if db.private_data != nil { + setErr(err, "AdbcDatabaseNew: database already allocated") + return C.ADBC_STATUS_INVALID_STATE + } + dbobj := &cDatabase{opts: make(map[string]string)} + hndl := cgo.NewHandle(dbobj) + db.private_data = createHandle(hndl) + return C.ADBC_STATUS_OK +} + //export SnowflakeDatabaseRelease func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -250,7 +566,99 @@ func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError return C.ADBC_STATUS_OK } +//export SnowflakeDatabaseSetOption +func SnowflakeDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOption", e) + } + }() + if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + return C.ADBC_STATUS_INVALID_STATE + } + cdb := getFromHandle[cDatabase](db.private_data) + + k, v := C.GoString(key), C.GoString(value) + if cdb.db != nil { + opts, ok := cdb.db.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(k, v))) + } else { + cdb.opts[k] = v + } + + return C.ADBC_STATUS_OK +} + +//export SnowflakeDatabaseSetOptionBytes +func SnowflakeDatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export SnowflakeDatabaseSetOptionDouble +func SnowflakeDatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionDouble", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionDouble") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export SnowflakeDatabaseSetOptionInt +func SnowflakeDatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) +} + type cConn struct { + cancellableContext + cnxn adbc.Connection initArgs map[string]string } @@ -284,6 +692,102 @@ func checkConnInit(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname return conn } +//export SnowflakeConnectionGetOption +func SnowflakeConnectionGetOption(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOption", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOption") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export SnowflakeConnectionGetOptionBytes +func SnowflakeConnectionGetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export SnowflakeConnectionGetOptionDouble +func SnowflakeConnectionGetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export SnowflakeConnectionGetOptionInt +func SnowflakeConnectionGetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + //export SnowflakeConnectionNew func SnowflakeConnectionNew(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -327,13 +831,75 @@ func SnowflakeConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cch return C.ADBC_STATUS_OK } - opts, ok := conn.cnxn.(adbc.PostInitOptions) + opts, ok := conn.cnxn.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcConnectionSetOption: not supported post-init") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))) +} + +//export SnowflakeConnectionSetOptionBytes +func SnowflakeConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export SnowflakeConnectionSetOptionDouble +func SnowflakeConnectionSetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export SnowflakeConnectionSetOptionInt +func SnowflakeConnectionSetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) if !ok { - setErr(err, "AdbcConnectionSetOption: not supported post-init") + setErr(err, "AdbcConnectionSetOptionInt: options are not supported") return C.ADBC_STATUS_NOT_IMPLEMENTED } - rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val))) - return C.AdbcStatusCode(rawCode) + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export SnowflakeConnectionInit @@ -396,8 +962,9 @@ func SnowflakeConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_Adb conn := h.Value().(*cConn) defer func() { + conn.cancelContext() conn.cnxn = nil - C.free(unsafe.Pointer(cnxn.private_data)) + C.free(cnxn.private_data) cnxn.private_data = nil h.Delete() // manually trigger GC for two reasons: @@ -434,26 +1001,19 @@ func toCdataArray(ptr *C.struct_ArrowArray) *cdata.CArrowArray { return (*cdata.CArrowArray)(unsafe.Pointer(ptr)) } -//export SnowflakeConnectionGetInfo -func SnowflakeConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.uint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export SnowflakeConnectionCancel +func SnowflakeConnectionCancel(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcConnectionGetInfo", e) + code = poison(err, "AdbcConnectionCancel", e) } }() - conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + conn := checkConnInit(cnxn, err, "AdbcConnectionCancel") if conn == nil { return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) - rdr, e := conn.cnxn.GetInfo(context.Background(), infoCodes) - if e != nil { - return C.AdbcStatusCode(errToAdbcErr(err, e)) - } - - defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + conn.cancelContext() return C.ADBC_STATUS_OK } @@ -481,6 +1041,29 @@ func toStrSlice(in **C.cchar_t) []string { return out } +//export SnowflakeConnectionGetInfo +func SnowflakeConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetInfo", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + //export SnowflakeConnectionGetObjects func SnowflakeConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, catalog, dbSchema, tableName *C.cchar_t, tableType **C.cchar_t, columnName *C.cchar_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { @@ -494,12 +1077,67 @@ func SnowflakeConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, c return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetObjects(context.Background(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + rdr, e := conn.cnxn.GetObjects(conn.newContext(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export SnowflakeConnectionGetStatistics +func SnowflakeConnectionGetStatistics(cnxn *C.struct_AdbcConnection, catalog, dbSchema, tableName *C.cchar_t, approximate C.char, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatistics(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), int(approximate) != 0) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export SnowflakeConnectionGetStatisticNames +func SnowflakeConnectionGetStatisticNames(cnxn *C.struct_AdbcConnection, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatisticNames(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -515,7 +1153,7 @@ func SnowflakeConnectionGetTableSchema(cnxn *C.struct_AdbcConnection, catalog, d return C.ADBC_STATUS_INVALID_STATE } - sc, e := conn.cnxn.GetTableSchema(context.Background(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) + sc, e := conn.cnxn.GetTableSchema(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -535,12 +1173,12 @@ func SnowflakeConnectionGetTableTypes(cnxn *C.struct_AdbcConnection, out *C.stru return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetTableTypes(context.Background()) + rdr, e := conn.cnxn.GetTableTypes(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -556,12 +1194,12 @@ func SnowflakeConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialized return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(context.Background(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -577,7 +1215,7 @@ func SnowflakeConnectionCommit(cnxn *C.struct_AdbcConnection, err *C.struct_Adbc return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(conn.newContext()))) } //export SnowflakeConnectionRollback @@ -592,25 +1230,137 @@ func SnowflakeConnectionRollback(cnxn *C.struct_AdbcConnection, err *C.struct_Ad return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(conn.newContext()))) +} + +type cStmt struct { + cancellableContext + + stmt adbc.Statement } -func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) adbc.Statement { +func checkStmtAlloc(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) - return nil + return false } if stmt == nil { setErr(err, "%s: statement not allocated", fname) - return nil + return false } - if stmt.private_data == nil { - setErr(err, "%s: statement not initialized", fname) + setErr(err, "%s: statement not allocated", fname) + return false + } + return true +} + +func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) *cStmt { + if !checkStmtAlloc(stmt, err, fname) { + return nil + } + cStmt := getFromHandle[cStmt](stmt.private_data) + if cStmt.stmt == nil { + setErr(err, "%s: statement not allocated", fname) return nil } + return cStmt +} + +//export SnowflakeStatementGetOption +func SnowflakeStatementGetOption(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOption", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOption") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export SnowflakeStatementGetOptionBytes +func SnowflakeStatementGetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export SnowflakeStatementGetOptionDouble +func SnowflakeStatementGetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export SnowflakeStatementGetOptionInt +func SnowflakeStatementGetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return (*(*cgo.Handle)(stmt.private_data)).Value().(adbc.Statement) + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export SnowflakeStatementNew @@ -639,8 +1389,8 @@ func SnowflakeStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSta return C.AdbcStatusCode(errToAdbcErr(err, e)) } - h := cgo.NewHandle(st) - stmt.private_data = createHandle(h) + hndl := cgo.NewHandle(&cStmt{stmt: st}) + stmt.private_data = createHandle(hndl) return C.ADBC_STATUS_OK } @@ -655,31 +1405,46 @@ func SnowflakeStatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_AdbcE setErr(err, "AdbcStatementRelease: Go panicked, driver is in unknown state") return C.ADBC_STATUS_INTERNAL } - if stmt == nil { - setErr(err, "AdbcStatementRelease: statement not allocated") + if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } + h := (*(*cgo.Handle)(stmt.private_data)) - if stmt.private_data == nil { - setErr(err, "AdbcStatementRelease: statement not initialized") - return C.ADBC_STATUS_INVALID_STATE + st := h.Value().(*cStmt) + defer func() { + st.cancelContext() + st.stmt = nil + C.free(stmt.private_data) + stmt.private_data = nil + h.Delete() + // manually trigger GC for two reasons: + // 1. ASAN expects the release callback to be called before + // the process ends, but GC is not deterministic. So by manually + // triggering the GC we ensure the release callback gets called. + // 2. Creates deterministic GC behavior by all Release functions + // triggering a garbage collection + runtime.GC() + }() + if st.stmt == nil { + return C.ADBC_STATUS_OK } + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Close())) +} - h := (*(*cgo.Handle)(stmt.private_data)) - st := h.Value().(adbc.Statement) - C.free(stmt.private_data) - stmt.private_data = nil +//export SnowflakeStatementCancel +func SnowflakeStatementCancel(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementCancel", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementCancel") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } - e := st.Close() - h.Delete() - // manually trigger GC for two reasons: - // 1. ASAN expects the release callback to be called before - // the process ends, but GC is not deterministic. So by manually - // triggering the GC we ensure the release callback gets called. - // 2. Creates deterministic GC behavior by all Release functions - // triggering a garbage collection - runtime.GC() - return C.AdbcStatusCode(errToAdbcErr(err, e)) + st.cancelContext() + return C.ADBC_STATUS_OK } //export SnowflakeStatementPrepare @@ -694,7 +1459,7 @@ func SnowflakeStatementPrepare(stmt *C.struct_AdbcStatement, err *C.struct_AdbcE return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.Prepare(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Prepare(st.newContext()))) } //export SnowflakeStatementExecuteQuery @@ -710,7 +1475,7 @@ func SnowflakeStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ } if out == nil { - n, e := st.ExecuteUpdate(context.Background()) + n, e := st.stmt.ExecuteUpdate(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -719,7 +1484,7 @@ func SnowflakeStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ *affected = C.int64_t(n) } } else { - rdr, n, e := st.ExecuteQuery(context.Background()) + rdr, n, e := st.stmt.ExecuteQuery(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -729,8 +1494,35 @@ func SnowflakeStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) + } + return C.ADBC_STATUS_OK +} + +//export SnowflakeStatementExecuteSchema +func SnowflakeStatementExecuteSchema(stmt *C.struct_AdbcStatement, schema *C.struct_ArrowSchema, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementExecuteQuery", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementExecuteQuery") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + es, ok := st.stmt.(adbc.StatementExecuteSchema) + if !ok { + setErr(err, "AdbcStatementExecuteSchema: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + sc, e := es.ExecuteSchema(st.newContext()) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) } + + cdata.ExportArrowSchema(sc, toCdataSchema(schema)) return C.ADBC_STATUS_OK } @@ -746,7 +1538,7 @@ func SnowflakeStatementSetSqlQuery(stmt *C.struct_AdbcStatement, query *C.cchar_ return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSqlQuery(C.GoString(query)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSqlQuery(C.GoString(query)))) } //export SnowflakeStatementSetSubstraitPlan @@ -761,7 +1553,7 @@ func SnowflakeStatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C.cu return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) } //export SnowflakeStatementBind @@ -784,7 +1576,7 @@ func SnowflakeStatementBind(stmt *C.struct_AdbcStatement, values *C.struct_Arrow } defer rec.Release() - return C.AdbcStatusCode(errToAdbcErr(err, st.Bind(context.Background(), rec))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Bind(st.newContext(), rec))) } //export SnowflakeStatementBindStream @@ -803,7 +1595,7 @@ func SnowflakeStatementBindStream(stmt *C.struct_AdbcStatement, stream *C.struct if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } - return C.AdbcStatusCode(errToAdbcErr(err, st.BindStream(context.Background(), rdr.(array.RecordReader)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.BindStream(st.newContext(), rdr.(array.RecordReader)))) } //export SnowflakeStatementGetParameterSchema @@ -818,7 +1610,7 @@ func SnowflakeStatementGetParameterSchema(stmt *C.struct_AdbcStatement, schema * return C.ADBC_STATUS_INVALID_STATE } - sc, e := st.GetParameterSchema() + sc, e := st.stmt.GetParameterSchema() if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -839,7 +1631,70 @@ func SnowflakeStatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.cch return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetOption(C.GoString(key), C.GoString(value)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetOption(C.GoString(key), C.GoString(value)))) +} + +//export SnowflakeStatementSetOptionBytes +func SnowflakeStatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export SnowflakeStatementSetOptionDouble +func SnowflakeStatementSetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export SnowflakeStatementSetOptionInt +func SnowflakeStatementSetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export releasePartitions @@ -868,7 +1723,7 @@ func SnowflakeStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C return C.ADBC_STATUS_INVALID_STATE } - sc, part, n, e := st.ExecutePartitions(context.Background()) + sc, part, n, e := st.stmt.ExecutePartitions(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -913,13 +1768,20 @@ func SnowflakeStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C //export SnowflakeDriverInit func SnowflakeDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcError) C.AdbcStatusCode { - if version != C.ADBC_VERSION_1_0_0 { - setErr(err, "Only version %d supported, got %d", int(C.ADBC_VERSION_1_0_0), int(version)) + driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) + + switch version { + case C.ADBC_VERSION_1_0_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) + memory.Set(sink, 0) + case C.ADBC_VERSION_1_1_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) + memory.Set(sink, 0) + default: + setErr(err, "Only version 1.0.0/1.1.0 supported, got %d", int(version)) return C.ADBC_STATUS_NOT_IMPLEMENTED } - driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) - C.memset(unsafe.Pointer(driver), 0, C.sizeof_struct_AdbcDriver) driver.DatabaseInit = (*[0]byte)(C.SnowflakeDatabaseInit) driver.DatabaseNew = (*[0]byte)(C.SnowflakeDatabaseNew) driver.DatabaseRelease = (*[0]byte)(C.SnowflakeDatabaseRelease) @@ -949,6 +1811,41 @@ func SnowflakeDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcErr driver.StatementGetParameterSchema = (*[0]byte)(C.SnowflakeStatementGetParameterSchema) driver.StatementPrepare = (*[0]byte)(C.SnowflakeStatementPrepare) + if version == C.ADBC_VERSION_1_1_0 { + driver.ErrorGetDetailCount = (*[0]byte)(C.SnowflakeErrorGetDetailCount) + driver.ErrorGetDetail = (*[0]byte)(C.SnowflakeErrorGetDetail) + driver.ErrorFromArrayStream = (*[0]byte)(C.SnowflakeErrorFromArrayStream) + + driver.DatabaseGetOption = (*[0]byte)(C.SnowflakeDatabaseGetOption) + driver.DatabaseGetOptionBytes = (*[0]byte)(C.SnowflakeDatabaseGetOptionBytes) + driver.DatabaseGetOptionDouble = (*[0]byte)(C.SnowflakeDatabaseGetOptionDouble) + driver.DatabaseGetOptionInt = (*[0]byte)(C.SnowflakeDatabaseGetOptionInt) + driver.DatabaseSetOptionBytes = (*[0]byte)(C.SnowflakeDatabaseSetOptionBytes) + driver.DatabaseSetOptionDouble = (*[0]byte)(C.SnowflakeDatabaseSetOptionDouble) + driver.DatabaseSetOptionInt = (*[0]byte)(C.SnowflakeDatabaseSetOptionInt) + + driver.ConnectionCancel = (*[0]byte)(C.SnowflakeConnectionCancel) + driver.ConnectionGetOption = (*[0]byte)(C.SnowflakeConnectionGetOption) + driver.ConnectionGetOptionBytes = (*[0]byte)(C.SnowflakeConnectionGetOptionBytes) + driver.ConnectionGetOptionDouble = (*[0]byte)(C.SnowflakeConnectionGetOptionDouble) + driver.ConnectionGetOptionInt = (*[0]byte)(C.SnowflakeConnectionGetOptionInt) + driver.ConnectionGetStatistics = (*[0]byte)(C.SnowflakeConnectionGetStatistics) + driver.ConnectionGetStatisticNames = (*[0]byte)(C.SnowflakeConnectionGetStatisticNames) + driver.ConnectionSetOptionBytes = (*[0]byte)(C.SnowflakeConnectionSetOptionBytes) + driver.ConnectionSetOptionDouble = (*[0]byte)(C.SnowflakeConnectionSetOptionDouble) + driver.ConnectionSetOptionInt = (*[0]byte)(C.SnowflakeConnectionSetOptionInt) + + driver.StatementCancel = (*[0]byte)(C.SnowflakeStatementCancel) + driver.StatementExecuteSchema = (*[0]byte)(C.SnowflakeStatementExecuteSchema) + driver.StatementGetOption = (*[0]byte)(C.SnowflakeStatementGetOption) + driver.StatementGetOptionBytes = (*[0]byte)(C.SnowflakeStatementGetOptionBytes) + driver.StatementGetOptionDouble = (*[0]byte)(C.SnowflakeStatementGetOptionDouble) + driver.StatementGetOptionInt = (*[0]byte)(C.SnowflakeStatementGetOptionInt) + driver.StatementSetOptionBytes = (*[0]byte)(C.SnowflakeStatementSetOptionBytes) + driver.StatementSetOptionDouble = (*[0]byte)(C.SnowflakeStatementSetOptionDouble) + driver.StatementSetOptionInt = (*[0]byte)(C.SnowflakeStatementSetOptionInt) + } + return C.ADBC_STATUS_OK } diff --git a/go/adbc/pkg/snowflake/utils.c b/go/adbc/pkg/snowflake/utils.c index 24d3ca3d90..a2bc39ebf1 100644 --- a/go/adbc/pkg/snowflake/utils.c +++ b/go/adbc/pkg/snowflake/utils.c @@ -35,52 +35,142 @@ void Snowflake_release_error(struct AdbcError* error) { error->release = NULL; } -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return SnowflakeDatabaseNew(database, error); +void SnowflakeReleaseErrWithDetails(struct AdbcError* error) { + if (!error || error->release != SnowflakeReleaseErrWithDetails || + !error->private_data) { + return; + } + + struct SnowflakeError* details = + (struct SnowflakeError*) error->private_data; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + free(details->keys); + free(details->values); + free(details->lengths); + free(details); + + free(error->message); + error->message = NULL; + error->release = NULL; + error->private_data = NULL; } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return SnowflakeDatabaseSetOption(database, key, value, error); +int SnowflakeErrorGetDetailCount(const struct AdbcError* error) { + if (!error || error->release != SnowflakeReleaseErrWithDetails || + !error->private_data) { + return 0; + } + + return ((struct SnowflakeError*) error->private_data)->count; +} + +struct AdbcErrorDetail SnowflakeErrorGetDetail(const struct AdbcError* error, + int index) { + if (!error || error->release != SnowflakeReleaseErrWithDetails || + !error->private_data) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct SnowflakeError* details = (struct SnowflakeError*) error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index] + }; +} + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return SnowflakeErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return SnowflakeErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return SnowflakeErrorFromArrayStream(stream, status); +} + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SnowflakeDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return SnowflakeDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return SnowflakeDatabaseGetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return SnowflakeDatabaseGetOptionInt(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return SnowflakeDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return SnowflakeDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return SnowflakeDatabaseRelease(database, error); } -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, - struct AdbcError* error) { - return SnowflakeConnectionNew(connection, error); +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return SnowflakeDatabaseSetOption(database, key, value, error); } -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, - const char* value, struct AdbcError* error) { - return SnowflakeConnectionSetOption(connection, key, value, error); +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return SnowflakeDatabaseSetOptionBytes(database, key, value, length, error); } -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, - struct AdbcDatabase* database, - struct AdbcError* error) { - return SnowflakeConnectionInit(connection, database, error); +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return SnowflakeDatabaseSetOptionDouble(database, key, value, error); } -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, - struct AdbcError* error) { - return SnowflakeConnectionRelease(connection, error); +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return SnowflakeDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return SnowflakeConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return SnowflakeConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); - return SnowflakeConnectionGetInfo(connection, info_codes, info_codes_length, out, - error); + return SnowflakeConnectionGetInfo(connection, info_codes, info_codes_length, + out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -91,7 +181,46 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return SnowflakeConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_type, column_name, out, error); + table_type, column_name, out, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SnowflakeConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SnowflakeConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return SnowflakeConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return SnowflakeConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return SnowflakeConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return SnowflakeConnectionGetStatisticNames(connection, out, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -101,7 +230,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, struct AdbcError* error) { if (schema) memset(schema, 0, sizeof(*schema)); return SnowflakeConnectionGetTableSchema(connection, catalog, db_schema, table_name, - schema, error); + schema, error); } AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, @@ -111,6 +240,17 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, return SnowflakeConnectionGetTableTypes(connection, out, error); } +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return SnowflakeConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return SnowflakeConnectionNew(connection, error); +} + AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, const uint8_t* serialized_partition, size_t serialized_length, @@ -118,12 +258,12 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return SnowflakeConnectionReadPartition(connection, serialized_partition, - serialized_length, out, error); + serialized_length, out, error); } -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, - struct AdbcError* error) { - return SnowflakeConnectionCommit(connection, error); +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return SnowflakeConnectionRelease(connection, error); } AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, @@ -131,39 +271,32 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return SnowflakeConnectionRollback(connection, error); } -AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, - struct AdbcStatement* statement, - struct AdbcError* error) { - return SnowflakeStatementNew(connection, statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, - struct AdbcError* error) { - return SnowflakeStatementRelease(statement, error); +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return SnowflakeConnectionSetOption(connection, key, value, error); } -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, - struct ArrowArrayStream* out, - int64_t* rows_affected, - struct AdbcError* error) { - if (out) memset(out, 0, sizeof(*out)); - return SnowflakeStatementExecuteQuery(statement, out, rows_affected, error); +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SnowflakeConnectionSetOptionBytes(connection, key, value, length, error); } -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, - struct AdbcError* error) { - return SnowflakeStatementPrepare(statement, error); +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return SnowflakeConnectionSetOptionDouble(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, - const char* query, struct AdbcError* error) { - return SnowflakeStatementSetSqlQuery(statement, query, error); +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return SnowflakeConnectionSetOptionInt(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, - const uint8_t* plan, size_t length, - struct AdbcError* error) { - return SnowflakeStatementSetSubstraitPlan(statement, plan, length, error); +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return SnowflakeStatementCancel(statement, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -178,6 +311,56 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return SnowflakeStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + if (partitions) memset(partitions, 0, sizeof(*partitions)); + return SnowflakeStatementExecutePartitions(statement, schema, partitions, + rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (out) memset(out, 0, sizeof(*out)); + return SnowflakeStatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + return SnowflakeStatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SnowflakeStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SnowflakeStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return SnowflakeStatementGetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + return SnowflakeStatementGetOptionInt(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -185,20 +368,54 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, return SnowflakeStatementGetParameterSchema(statement, schema, error); } +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return SnowflakeStatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return SnowflakeStatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return SnowflakeStatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return SnowflakeStatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return SnowflakeStatementSetSubstraitPlan(statement, plan, length, error); +} + AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { return SnowflakeStatementSetOption(statement, key, value, error); } -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, - struct ArrowSchema* schema, - struct AdbcPartitions* partitions, - int64_t* rows_affected, - struct AdbcError* error) { - if (schema) memset(schema, 0, sizeof(*schema)); - if (partitions) memset(partitions, 0, sizeof(*partitions)); - return SnowflakeStatementExecutePartitions(statement, schema, partitions, rows_affected, - error); +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SnowflakeStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return SnowflakeStatementSetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + return SnowflakeStatementSetOptionInt(statement, key, value, error); } ADBC_EXPORT @@ -206,6 +423,23 @@ AdbcStatusCode AdbcDriverInit(int version, void* driver, struct AdbcError* error return SnowflakeDriverInit(version, driver, error); } +int SnowflakeArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +int SnowflakeArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); + +int SnowflakeArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return SnowflakeArrayStreamGetSchema(stream, out); +} + +int SnowflakeArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return SnowflakeArrayStreamGetNext(stream, out); +} + #ifdef __cplusplus } #endif diff --git a/go/adbc/pkg/snowflake/utils.h b/go/adbc/pkg/snowflake/utils.h index 453d7a099a..23391dfd70 100644 --- a/go/adbc/pkg/snowflake/utils.h +++ b/go/adbc/pkg/snowflake/utils.h @@ -26,72 +26,151 @@ #include #include "../../drivermgr/adbc.h" +struct AdbcError* SnowflakeErrorFromArrayStream(struct ArrowArrayStream*, + AdbcStatusCode*); +AdbcStatusCode SnowflakeDatabaseGetOption(struct AdbcDatabase*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseGetOptionBytes(struct AdbcDatabase*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseGetOptionDouble(struct AdbcDatabase*, const char*, + double*, struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseGetOptionInt(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode SnowflakeDatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode SnowflakeDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode SnowflakeDatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); -AdbcStatusCode SnowflakeDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode SnowflakeDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionSetOption(struct AdbcConnection* cnxn, const char* key, - const char* val, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionInit(struct AdbcConnection* cnxn, - struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionRelease(struct AdbcConnection* cnxn, - struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionGetInfo(struct AdbcConnection* cnxn, uint32_t* codes, - size_t len, struct ArrowArrayStream* out, +AdbcStatusCode SnowflakeDatabaseSetOptionBytes(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseSetOptionDouble(struct AdbcDatabase*, const char*, double, + struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseSetOptionInt(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + +AdbcStatusCode SnowflakeConnectionCancel(struct AdbcConnection*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionCommit(struct AdbcConnection* cnxn, + struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionGetInfo(struct AdbcConnection* cnxn, + const uint32_t* codes, size_t len, + struct ArrowArrayStream* out, struct AdbcError* err); AdbcStatusCode SnowflakeConnectionGetObjects( struct AdbcConnection* cnxn, int depth, const char* catalog, const char* dbSchema, const char* tableName, const char** tableType, const char* columnName, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionGetOption(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetOptionBytes(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetOptionDouble(struct AdbcConnection*, const char*, + double*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetOptionInt(struct AdbcConnection*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetStatistics(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, + struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); AdbcStatusCode SnowflakeConnectionGetTableSchema( struct AdbcConnection* cnxn, const char* catalog, const char* dbSchema, const char* tableName, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode SnowflakeConnectionGetTableTypes(struct AdbcConnection* cnxn, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionInit(struct AdbcConnection* cnxn, + struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); AdbcStatusCode SnowflakeConnectionReadPartition(struct AdbcConnection* cnxn, const uint8_t* serialized, size_t serializedLen, struct ArrowArrayStream* out, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionCommit(struct AdbcConnection* cnxn, - struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionRelease(struct AdbcConnection* cnxn, + struct AdbcError* err); AdbcStatusCode SnowflakeConnectionRollback(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementNew(struct AdbcConnection* cnxn, - struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementRelease(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode SnowflakeStatementPrepare(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode SnowflakeStatementExecuteQuery(struct AdbcStatement* stmt, - struct ArrowArrayStream* out, - int64_t* affected, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementSetSqlQuery(struct AdbcStatement* stmt, - const char* query, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementSetSubstraitPlan(struct AdbcStatement* stmt, - const uint8_t* plan, size_t length, - struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionSetOption(struct AdbcConnection* cnxn, const char* key, + const char* val, struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode SnowflakeConnectionSetOptionDouble(struct AdbcConnection*, const char*, + double, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionSetOptionInt(struct AdbcConnection*, const char*, + int64_t, struct AdbcError*); + AdbcStatusCode SnowflakeStatementBind(struct AdbcStatement* stmt, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode SnowflakeStatementBindStream(struct AdbcStatement* stmt, struct ArrowArrayStream* stream, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementGetParameterSchema(struct AdbcStatement* stmt, - struct ArrowSchema* schema, - struct AdbcError* err); -AdbcStatusCode SnowflakeStatementSetOption(struct AdbcStatement* stmt, const char* key, - const char* value, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementCancel(struct AdbcStatement*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementExecuteQuery(struct AdbcStatement* stmt, + struct ArrowArrayStream* out, + int64_t* affected, struct AdbcError* err); AdbcStatusCode SnowflakeStatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementExecuteSchema(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetOption(struct AdbcStatement*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetOptionBytes(struct AdbcStatement*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetOptionDouble(struct AdbcStatement*, const char*, + double*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetOptionInt(struct AdbcStatement*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetParameterSchema(struct AdbcStatement* stmt, + struct ArrowSchema* schema, + struct AdbcError* err); +AdbcStatusCode SnowflakeStatementNew(struct AdbcConnection* cnxn, + struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementPrepare(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode SnowflakeStatementRelease(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode SnowflakeStatementSetOption(struct AdbcStatement* stmt, const char* key, + const char* value, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementSetOptionBytes(struct AdbcStatement*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode SnowflakeStatementSetOptionDouble(struct AdbcStatement*, const char*, + double, struct AdbcError*); +AdbcStatusCode SnowflakeStatementSetOptionInt(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); +AdbcStatusCode SnowflakeStatementSetSqlQuery(struct AdbcStatement* stmt, + const char* query, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementSetSubstraitPlan(struct AdbcStatement* stmt, + const uint8_t* plan, size_t length, + struct AdbcError* err); + AdbcStatusCode SnowflakeDriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void SnowflakeerrRelease(struct AdbcError* error) { error->release(error); } void Snowflake_release_error(struct AdbcError* error); + +struct SnowflakeError { + char* message; + char** keys; + uint8_t** values; + size_t* lengths; + int count; +}; + +void SnowflakeReleaseErrWithDetails(struct AdbcError* error); + +int SnowflakeErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail SnowflakeErrorGetDetail(const struct AdbcError* error, int index); + +int SnowflakeArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out); +int SnowflakeArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out); diff --git a/go/adbc/standard_schemas.go b/go/adbc/standard_schemas.go index b5ca7d42b5..5ad1ae8ba0 100644 --- a/go/adbc/standard_schemas.go +++ b/go/adbc/standard_schemas.go @@ -92,6 +92,34 @@ var ( {Name: "catalog_db_schemas", Type: arrow.ListOf(DBSchemaSchema), Nullable: true}, }, nil) + StatisticsSchema = arrow.StructOf( + arrow.Field{Name: "table_name", Type: arrow.BinaryTypes.String, Nullable: false}, + arrow.Field{Name: "column_name", Type: arrow.BinaryTypes.String, Nullable: true}, + arrow.Field{Name: "statistic_key", Type: arrow.PrimitiveTypes.Int16, Nullable: false}, + arrow.Field{Name: "statistic_value", Type: arrow.DenseUnionOf([]arrow.Field{ + {Name: "int64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "uint64", Type: arrow.PrimitiveTypes.Uint64, Nullable: true}, + {Name: "float64", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "binary", Type: arrow.BinaryTypes.Binary, Nullable: true}, + }, []arrow.UnionTypeCode{0, 1, 2, 3}), Nullable: false}, + arrow.Field{Name: "statistic_is_approximate", Type: arrow.FixedWidthTypes.Boolean, Nullable: false}, + ) + + StatisticsDBSchemaSchema = arrow.StructOf( + arrow.Field{Name: "db_schema_name", Type: arrow.BinaryTypes.String, Nullable: true}, + arrow.Field{Name: "db_schema_statistics", Type: arrow.ListOf(StatisticsSchema), Nullable: false}, + ) + + GetStatisticsSchema = arrow.NewSchema([]arrow.Field{ + {Name: "catalog_name", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "catalog_db_schemas", Type: arrow.ListOf(StatisticsDBSchemaSchema), Nullable: false}, + }, nil) + + GetStatisticNamesSchema = arrow.NewSchema([]arrow.Field{ + {Name: "statistic_name", Type: arrow.BinaryTypes.String, Nullable: false}, + {Name: "statistic_key", Type: arrow.PrimitiveTypes.Int16, Nullable: false}, + }, nil) + GetTableSchemaSchema = arrow.NewSchema([]arrow.Field{ {Name: "catalog_name", Type: arrow.BinaryTypes.String, Nullable: true}, {Name: "db_schema_name", Type: arrow.BinaryTypes.String, Nullable: true}, diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go index ffc9e93dc8..7f9832ce3d 100644 --- a/go/adbc/validation/validation.go +++ b/go/adbc/validation/validation.go @@ -44,10 +44,20 @@ type DriverQuirks interface { DatabaseOptions() map[string]string // Return the SQL to reference the bind parameter for a given index BindParameter(index int) string + // Whether the driver supports bulk ingest + SupportsBulkIngest(mode string) bool // Whether two statements can be used at the same time on a single connection SupportsConcurrentStatements() bool + // Whether current catalog/schema are supported + SupportsCurrentCatalogSchema() bool + // Whether GetSetOptions is supported + SupportsGetSetOptions() bool + // Whether AdbcStatementExecuteSchema should work + SupportsExecuteSchema() bool // Whether AdbcStatementExecutePartitions should work SupportsPartitionedData() bool + // Whether statistics are supported + SupportsStatistics() bool // Whether transactions are supported (Commit/Rollback on connection) SupportsTransactions() bool // Whether retrieving the schema of prepared statement params is supported @@ -60,11 +70,10 @@ type DriverQuirks interface { CreateSampleTable(tableName string, r arrow.Record) error // Field Metadata for Sample Table for comparison SampleTableSchemaMetadata(tblName string, dt arrow.DataType) arrow.Metadata - // Whether the driver supports bulk ingest - SupportsBulkIngest() bool // have the driver drop a table with the correct SQL syntax DropTable(adbc.Connection, string) error + Catalog() string DBSchema() string Alloc() memory.Allocator @@ -115,6 +124,30 @@ func (c *ConnectionTests) TearDownTest() { c.DB = nil } +func (c *ConnectionTests) TestGetSetOptions() { + cnxn, err := c.DB.Open(context.Background()) + c.NoError(err) + c.NotNil(cnxn) + + stmt, err := cnxn.NewStatement() + c.NoError(err) + c.NotNil(stmt) + + expected := c.Quirks.SupportsGetSetOptions() + + _, ok := c.DB.(adbc.GetSetOptions) + c.Equal(expected, ok) + + _, ok = cnxn.(adbc.GetSetOptions) + c.Equal(expected, ok) + + _, ok = stmt.(adbc.GetSetOptions) + c.Equal(expected, ok) + + c.NoError(stmt.Close()) + c.NoError(cnxn.Close()) +} + func (c *ConnectionTests) TestNewConn() { cnxn, err := c.DB.Open(context.Background()) c.NoError(err) @@ -152,6 +185,12 @@ func (c *ConnectionTests) TestAutocommitDefault() { cnxn, _ := c.DB.Open(ctx) defer cnxn.Close() + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueEnabled, value) + } + expectedCode := adbc.StatusInvalidState var adbcError adbc.Error err := cnxn.Commit(ctx) @@ -188,8 +227,60 @@ func (c *ConnectionTests) TestAutocommitToggle() { c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueEnabled)) c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueDisabled, value) + } + // it is ok to disable autocommit when it isn't enabled c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueDisabled, value) + } +} + +func (c *ConnectionTests) TestMetadataCurrentCatalog() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + getset, ok := cnxn.(adbc.GetSetOptions) + + if !c.Quirks.SupportsGetSetOptions() { + c.False(ok) + return + } + c.True(ok) + value, err := getset.GetOption(adbc.OptionKeyCurrentCatalog) + if c.Quirks.SupportsCurrentCatalogSchema() { + c.NoError(err) + c.Equal(c.Quirks.Catalog(), value) + } else { + c.Error(err) + } +} + +func (c *ConnectionTests) TestMetadataCurrentDbSchema() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + getset, ok := cnxn.(adbc.GetSetOptions) + + if !c.Quirks.SupportsGetSetOptions() { + c.False(ok) + return + } + c.True(ok) + value, err := getset.GetOption(adbc.OptionKeyCurrentDbSchema) + if c.Quirks.SupportsCurrentCatalogSchema() { + c.NoError(err) + c.Equal(c.Quirks.DBSchema(), value) + } else { + c.Error(err) + } } func (c *ConnectionTests) TestMetadataGetInfo() { @@ -201,6 +292,7 @@ func (c *ConnectionTests) TestMetadataGetInfo() { adbc.InfoDriverName, adbc.InfoDriverVersion, adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, adbc.InfoVendorName, adbc.InfoVendorVersion, adbc.InfoVendorArrowVersion, @@ -219,19 +311,55 @@ func (c *ConnectionTests) TestMetadataGetInfo() { valUnion := rec.Column(1).(*array.DenseUnion) for i := 0; i < int(rec.NumRows()); i++ { code := codeCol.Value(i) - child := valUnion.Field(valUnion.ChildID(i)) - if child.IsNull(i) { + offset := int(valUnion.ValueOffset(i)) + valUnion.GetOneForMarshal(i) + if child.IsNull(offset) { exp := c.Quirks.GetMetadata(adbc.InfoCode(code)) c.Nilf(exp, "got nil for info %s, expected: %s", adbc.InfoCode(code), exp) } else { - // currently we only define utf8 values for metadata - c.Equal(c.Quirks.GetMetadata(adbc.InfoCode(code)), child.(*array.String).Value(i), adbc.InfoCode(code).String()) + expected := c.Quirks.GetMetadata(adbc.InfoCode(code)) + var actual interface{} + + switch valUnion.ChildID(i) { + case 0: + // String + actual = child.(*array.String).Value(offset) + case 2: + // int64 + actual = child.(*array.Int64).Value(offset) + default: + c.FailNow("Unknown union type code", valUnion.ChildID(i)) + } + + c.Equal(expected, actual, adbc.InfoCode(code).String()) } } } } +func (c *ConnectionTests) TestMetadataGetStatistics() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + + if c.Quirks.SupportsStatistics() { + stats, ok := cnxn.(adbc.ConnectionGetStatistics) + c.True(ok) + reader, err := stats.GetStatistics(ctx, nil, nil, nil, true) + c.NoError(err) + defer reader.Release() + } else { + stats, ok := cnxn.(adbc.ConnectionGetStatistics) + if ok { + _, err := stats.GetStatistics(ctx, nil, nil, nil, true) + var adbcErr adbc.Error + c.ErrorAs(err, &adbcErr) + c.Equal(adbc.StatusNotImplemented, adbcErr.Code) + } + } +} + func (c *ConnectionTests) TestMetadataGetTableSchema() { rec, _, err := array.RecordFromJSON(c.Quirks.Alloc(), arrow.NewSchema( []arrow.Field{ @@ -407,6 +535,49 @@ func (s *StatementTests) TestNewStatement() { s.Equal(adbc.StatusInvalidState, adbcError.Code) } +func (s *StatementTests) TestSqlExecuteSchema() { + if !s.Quirks.SupportsExecuteSchema() { + s.T().SkipNow() + } + + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + es, ok := stmt.(adbc.StatementExecuteSchema) + s.Require().True(ok, "%#v does not support ExecuteSchema", es) + + s.Run("no query", func() { + var adbcErr adbc.Error + + schema, err := es.ExecuteSchema(s.ctx) + s.ErrorAs(err, &adbcErr) + s.Equal(adbc.StatusInvalidState, adbcErr.Code) + s.Nil(schema) + }) + + s.Run("query", func() { + s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'")) + + schema, err := es.ExecuteSchema(s.ctx) + s.NoError(err) + s.Equal(2, len(schema.Fields())) + s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64) + s.Equal(arrow.STRING, schema.Field(1).Type.ID()) + }) + + s.Run("prepared", func() { + s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'")) + s.NoError(stmt.Prepare(s.ctx)) + + schema, err := es.ExecuteSchema(s.ctx) + s.NoError(err) + s.Equal(2, len(schema.Fields())) + s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64) + s.Equal(arrow.STRING, schema.Field(1).Type.ID()) + }) +} + func (s *StatementTests) TestSqlPartitionedInts() { stmt, err := s.Cnxn.NewStatement() s.Require().NoError(err) @@ -596,7 +767,7 @@ func (s *StatementTests) TestSqlPrepareErrorParamCountMismatch() { } func (s *StatementTests) TestSqlIngestInts() { - if !s.Quirks.SupportsBulkIngest() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { s.T().SkipNow() } @@ -647,7 +818,7 @@ func (s *StatementTests) TestSqlIngestInts() { } func (s *StatementTests) TestSqlIngestAppend() { - if !s.Quirks.SupportsBulkIngest() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { s.T().SkipNow() } @@ -683,6 +854,10 @@ func (s *StatementTests) TestSqlIngestAppend() { defer batch2.Release() s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { + s.T().SkipNow() + } s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeAppend)) s.Require().NoError(stmt.Bind(s.ctx, batch2)) @@ -716,11 +891,151 @@ func (s *StatementTests) TestSqlIngestAppend() { s.Require().NoError(rdr.Err()) } +func (s *StatementTests) TestSqlIngestReplace() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeReplace) { + s.T().SkipNow() + } + + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + + schema := arrow.NewSchema([]arrow.Field{{ + Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + + batchbldr := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr.Release() + bldr := batchbldr.Field(0).(*array.Int64Builder) + bldr.AppendValues([]int64{42}, []bool{true}) + batch := batchbldr.NewRecord() + defer batch.Release() + + // ingest and create table + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err := stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + // now replace + schema = arrow.NewSchema([]arrow.Field{{ + Name: "newintcol", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + batchbldr2 := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr2.Release() + bldr2 := batchbldr2.Field(0).(*array.Int64Builder) + bldr2.AppendValues([]int64{42}, []bool{true}) + batch2 := batchbldr2.NewRecord() + defer batch2.Release() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeReplace)) + s.Require().NoError(stmt.Bind(s.ctx, batch2)) + + affected, err = stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`)) + rdr, rows, err := stmt.ExecuteQuery(s.ctx) + s.Require().NoError(err) + if rows != -1 && rows != 1 { + s.FailNowf("invalid number of returned rows", "should be -1 or 1, got: %d", rows) + } + defer rdr.Release() + + s.Truef(schema.Equal(utils.RemoveSchemaMetadata(rdr.Schema())), "expected: %s\n got: %s", schema, rdr.Schema()) + s.Require().True(rdr.Next()) + rec := rdr.Record() + s.EqualValues(1, rec.NumRows()) + s.EqualValues(1, rec.NumCols()) + col, ok := rec.Column(0).(*array.Int64) + s.True(ok) + s.Equal(int64(42), col.Value(0)) + + s.Require().False(rdr.Next()) + s.Require().NoError(rdr.Err()) +} + +func (s *StatementTests) TestSqlIngestCreateAppend() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreateAppend) { + s.T().SkipNow() + } + + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + + schema := arrow.NewSchema([]arrow.Field{{ + Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + + batchbldr := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr.Release() + bldr := batchbldr.Field(0).(*array.Int64Builder) + bldr.AppendValues([]int64{42}, []bool{true}) + batch := batchbldr.NewRecord() + defer batch.Release() + + // ingest and create table + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreateAppend)) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err := stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + // append + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreateAppend)) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err = stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + // validate + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`)) + rdr, rows, err := stmt.ExecuteQuery(s.ctx) + s.Require().NoError(err) + if rows != -1 && rows != 2 { + s.FailNowf("invalid number of returned rows", "should be -1 or 2, got: %d", rows) + } + defer rdr.Release() + + s.Truef(schema.Equal(utils.RemoveSchemaMetadata(rdr.Schema())), "expected: %s\n got: %s", schema, rdr.Schema()) + s.Require().True(rdr.Next()) + rec := rdr.Record() + s.EqualValues(2, rec.NumRows()) + s.EqualValues(1, rec.NumCols()) + col, ok := rec.Column(0).(*array.Int64) + s.True(ok) + s.Equal(int64(42), col.Value(0)) + s.Equal(int64(42), col.Value(1)) + + s.Require().False(rdr.Next()) + s.Require().NoError(rdr.Err()) +} + func (s *StatementTests) TestSqlIngestErrors() { - if !s.Quirks.SupportsBulkIngest() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { s.T().SkipNow() } + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + stmt, err := s.Cnxn.NewStatement() s.Require().NoError(err) defer stmt.Close() @@ -735,6 +1050,10 @@ func (s *StatementTests) TestSqlIngestErrors() { }) s.Run("append to nonexistent table", func() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { + s.T().SkipNow() + } + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) schema := arrow.NewSchema([]arrow.Field{{ Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) @@ -795,6 +1114,10 @@ func (s *StatementTests) TestSqlIngestErrors() { batch = batchbldr.NewRecord() defer batch.Release() + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { + s.T().SkipNow() + } + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeAppend)) s.Require().NoError(stmt.Bind(s.ctx, batch)) diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java index fea705482a..c8e897eeee 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java @@ -27,7 +27,21 @@ *

Connections are not required to be thread-safe, but they can be used from multiple threads so * long as clients take care to serialize accesses to a connection. */ -public interface AdbcConnection extends AutoCloseable { +public interface AdbcConnection extends AutoCloseable, AdbcOptions { + /** + * Cancel execution of a query. + * + *

This can be used to interrupt execution of a method like {@link #getObjects(GetObjectsDepth, + * String, String, String, String[], String)}}. + * + *

This method must be thread-safe (other method are not necessarily thread-safe). + * + * @since ADBC API revision 1.1.0 + */ + default void cancel() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support cancel"); + } + /** Commit the pending transaction. */ default void commit() throws AdbcException { throw AdbcException.notImplemented("Connection does not support transactions"); @@ -102,7 +116,7 @@ default ArrowReader getInfo() throws AdbcException { * The definition of the GetObjects result schema. * * - * DB_SCHEMA_SCHEMA is a Struct with fields: + *

DB_SCHEMA_SCHEMA is a Struct with fields: * * * @@ -111,7 +125,7 @@ default ArrowReader getInfo() throws AdbcException { * *
Field Name Field Type
The definition of DB_SCHEMA_SCHEMA.
* - * TABLE_SCHEMA is a Struct with fields: + *

TABLE_SCHEMA is a Struct with fields: * * * @@ -122,7 +136,7 @@ default ArrowReader getInfo() throws AdbcException { * *
Field Name Field Type
The definition of TABLE_SCHEMA.
* - * COLUMN_SCHEMA is a Struct with fields: + *

COLUMN_SCHEMA is a Struct with fields: * * * @@ -148,7 +162,7 @@ default ArrowReader getInfo() throws AdbcException { * *
Field Name Field Type Comments
The definition of COLUMN_SCHEMA.
* - * Notes: + *

Notes: * *

    *
  1. The column's ordinal position in the table (starting from 1). @@ -157,7 +171,7 @@ default ArrowReader getInfo() throws AdbcException { * provide JDBC/ODBC-compatible metadata in an agnostic manner. *
* - * CONSTRAINT_SCHEMA is a Struct with fields: + *

CONSTRAINT_SCHEMA is a Struct with fields: * * * @@ -174,7 +188,7 @@ default ArrowReader getInfo() throws AdbcException { *
  • For FOREIGN KEY only, the referenced table and columns. * * - * USAGE_SCHEMA is a Struct with fields: + *

    USAGE_SCHEMA is a Struct with fields: * *

  • Field Name Field Type Comments
    * @@ -227,6 +241,94 @@ enum GetObjectsDepth { TABLES, } + /** + * Get statistics about the data distribution of table(s). + * + *

    The result is an Arrow dataset with the following schema: + * + *

    Field Name Field Type
    + * + * + * + * + *
    Field Name Field Type
    catalog_name utf8
    catalog_db_schemas list[DB_SCHEMA_SCHEMA] not null
    The definition of the GetStatistics result schema.
    + * + *

    DB_SCHEMA_SCHEMA is a Struct with fields: + * + * + * + * + * + * + *
    Field Name Field Type
    db_schema_name utf8
    db_schema_statistics list[STATISTICS_SCHEMA] not null
    The definition of DB_SCHEMA_SCHEMA.
    + * + *

    STATISTICS_SCHEMA is a Struct with fields: + * + * + * + * + * + * + * + * + * + *
    Field Name Field Type Comments
    table_name utf8 not null
    column_name utf8 (1)
    statistic_key int16 not null (2)
    statistic_value VALUE_SCHEMA not null
    statistic_is_approximatebool not null (3)
    The definition of STATISTICS_SCHEMA.
    + * + *

      + *
    1. If null, then the statistic applies to the entire table. + *
    2. A dictionary-encoded statistic name (although we do not use the Arrow dictionary type). + * Values in [0, 1024) are reserved for ADBC. Other values are for implementation-specific + * statistics. For the definitions of predefined statistic types, see {@link + * StandardStatistics}. To get driver-specific statistic names, use {@link + * #getStatisticNames()}. + *
    3. If true, then the value is approximate or best-effort. + *
    + * + *

    VALUE_SCHEMA is a dense union with members: + * + * + * + * + * + * + * + * + *
    Field Name Field Type
    int64 int64
    uint64 uint64
    float64 float64
    binary binary
    The definition of VALUE_SCHEMA.
    + * + * @param catalogPattern Only show tables in the given catalog. If null, do not filter by catalog. + * If an empty string, only show tables without a catalog. May be a search pattern (see class + * documentation). + * @param dbSchemaPattern Only show tables in the given database schema. If null, do not filter by + * database schema. If an empty string, only show tables without a database schema. May be a + * search pattern (see class documentation). + * @param tableNamePattern Only show tables with the given name. If an empty string, only show + * tables without a catalog. May be a search pattern (see class documentation). + * @param approximate If false, request exact values of statistics, else allow for best-effort, + * approximate, or cached values. The database may return approximate values regardless, as + * indicated in the result. Requesting exact values may be expensive or unsupported. + */ + default ArrowReader getStatistics( + String catalogPattern, String dbSchemaPattern, String tableNamePattern, boolean approximate) + throws AdbcException { + throw AdbcException.notImplemented("Connection does not support getStatistics()"); + } + + /** + * Get the names of additional statistics defined by this driver. + * + *

    The result is an Arrow dataset with the following schema: + * + * + * + * + * + * + *
    Field Name Field Type
    statistic_name utf8 not null
    statistic_key int16 not null
    The definition of the GetStatistics result schema.
    + */ + default ArrowReader getStatisticNames() throws AdbcException { + throw AdbcException.notImplemented("Connection does not support getStatisticNames()"); + } + /** * Get the Arrow schema of a database table. * @@ -285,6 +387,42 @@ default void setAutoCommit(boolean enableAutoCommit) throws AdbcException { throw AdbcException.notImplemented("Connection does not support transactions"); } + /** + * Get the current catalog. + * + * @since ADBC API revision 1.1.0 + */ + default String getCurrentCatalog() throws AdbcException { + throw AdbcException.notImplemented("Connection does not support current catalog"); + } + + /** + * Set the current catalog. + * + * @since ADBC API revision 1.1.0 + */ + default void setCurrentCatalog(String catalog) throws AdbcException { + throw AdbcException.notImplemented("Connection does not support current catalog"); + } + + /** + * Get the current schema. + * + * @since ADBC API revision 1.1.0 + */ + default String getCurrentDbSchema() throws AdbcException { + throw AdbcException.notImplemented("Connection does not support current catalog"); + } + + /** + * Set the current schema. + * + * @since ADBC API revision 1.1.0 + */ + default void setCurrentDbSchema(String dbSchema) throws AdbcException { + throw AdbcException.notImplemented("Connection does not support current catalog"); + } + /** * Get whether the connection is read-only. * diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDatabase.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDatabase.java index e63c598be9..723acfc0da 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDatabase.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDatabase.java @@ -24,7 +24,7 @@ * remote/networked databases, for in-memory databases, this object provides an explicit point of * ownership. */ -public interface AdbcDatabase extends AutoCloseable { +public interface AdbcDatabase extends AutoCloseable, AdbcOptions { /** Create a new connection to the database. */ AdbcConnection connect() throws AdbcException; } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java index 80abd18560..5e32fd1ed7 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java @@ -21,11 +21,42 @@ /** A handle to an ADBC database driver. */ public interface AdbcDriver { - /** The standard parameter name for a connection URL (type String). */ - String PARAM_URL = "adbc.url"; + /** + * The standard parameter name for a password (type String). + * + * @since ADBC API revision 1.1.0 + */ + TypedKey PARAM_PASSWORD = new TypedKey<>("password", String.class); + + /** + * The standard parameter name for a connection URI (type String). + * + * @since ADBC API revision 1.1.0 + */ + TypedKey PARAM_URI = new TypedKey<>("uri", String.class); + + /** + * The standard parameter name for a connection URL (type String). + * + * @deprecated Prefer {@link #PARAM_URI} instead. + */ + @Deprecated String PARAM_URL = "adbc.url"; + + /** + * The standard parameter name for a username (type String). + * + * @since ADBC API revision 1.1.0 + */ + TypedKey PARAM_USERNAME = new TypedKey<>("username", String.class); + /** The standard parameter name for SQL quirks configuration (type SqlQuirks). */ String PARAM_SQL_QUIRKS = "adbc.sql.quirks"; + /** ADBC API revision 1.0.0. */ + long ADBC_VERSION_1_0_0 = 1_000_000; + /** ADBC API revision 1.1.0. */ + long ADBC_VERSION_1_1_0 = 1_001_000; + /** * Open a database via this driver. * diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java index be5a4c6bc1..dce7570e3d 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java @@ -16,6 +16,9 @@ */ package org.apache.arrow.adbc.core; +import java.util.Collection; +import java.util.Collections; + /** * An error in the database or ADBC driver. * @@ -33,13 +36,25 @@ public class AdbcException extends Exception { private final AdbcStatusCode status; private final String sqlState; private final int vendorCode; + private Collection details; public AdbcException( String message, Throwable cause, AdbcStatusCode status, String sqlState, int vendorCode) { + this(message, cause, status, sqlState, vendorCode, Collections.emptyList()); + } + + public AdbcException( + String message, + Throwable cause, + AdbcStatusCode status, + String sqlState, + int vendorCode, + Collection details) { super(message, cause); this.status = status; this.sqlState = sqlState; this.vendorCode = vendorCode; + this.details = details; } /** Create a new exception with code {@link AdbcStatusCode#INVALID_ARGUMENT}. */ @@ -77,11 +92,30 @@ public int getVendorCode() { return vendorCode; } + /** + * Get extra driver-specific error details. + * + *

    This allows drivers to return custom, structured error information (for example, JSON or + * Protocol Buffers) that can be optionally parsed by clients, beyond the standard AdbcError + * fields, without having to encode it in the error message. The encoding of the data is + * driver-defined. + */ + public Collection getDetails() { + return details; + } + /** * Copy this exception with a different cause (a convenience for use with the static factories). */ public AdbcException withCause(Throwable cause) { - return new AdbcException(this.getMessage(), cause, status, sqlState, vendorCode); + return new AdbcException(getMessage(), cause, status, sqlState, vendorCode, details); + } + + /** + * Copy this exception with different details (a convenience for use with the static factories). + */ + public AdbcException withDetails(Collection details) { + return new AdbcException(getMessage(), getCause(), status, sqlState, vendorCode, details); } @Override @@ -98,6 +132,8 @@ public String toString() { + vendorCode + ", cause=" + getCause() + + ", details=" + + getDetails().size() + '}'; } } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java index 52c0956564..8d5c73ba9f 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java @@ -16,7 +16,12 @@ */ package org.apache.arrow.adbc.core; -/** Integer IDs used for requesting information about the database/driver. */ +/** + * Integer IDs used for requesting information about the database/driver. + * + *

    Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" information, which is the same + * metadata provided by the same info code range in the Arrow Flight SQL GetSqlInfo RPC. + */ public enum AdbcInfoCode { /** The database vendor/product name (e.g. the server name) (type: utf8). */ VENDOR_NAME(0), @@ -31,6 +36,16 @@ public enum AdbcInfoCode { DRIVER_VERSION(101), /** The driver Arrow library version (type: utf8). */ DRIVER_ARROW_VERSION(102), + /** + * The ADBC API version (type: int64). + * + *

    The value should be one of the ADBC_VERSION constants. + * + * @see AdbcDriver#ADBC_VERSION_1_0_0 + * @see AdbcDriver#ADBC_VERSION_1_1_0 + * @since ADBC API revision 1.1.0 + */ + DRIVER_ADBC_VERSION(103), ; private final int value; diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java new file mode 100644 index 0000000000..5a8e78b08f --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.core; + +/** An ADBC object that supports getting/setting generic options. */ +public interface AdbcOptions { + /** + * Get a generic option. + * + * @since ADBC API revision 1.1.0 + * @param key The option to retrieve. + * @return The option value. + * @param The option value type. + */ + default T getOption(TypedKey key) throws AdbcException { + throw AdbcException.notImplemented("Unsupported option " + key); + } + + /** + * Set a generic option. + * + * @since ADBC API revision 1.1.0 + * @param key The option to set. + * @param value The option value. + * @param The option value type. + */ + default void setOption(TypedKey key, T value) throws AdbcException { + throw AdbcException.notImplemented("Unsupported option " + key); + } +} diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java index ef2be487e2..a1f9e0f3b4 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Iterator; import java.util.List; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; @@ -40,8 +41,25 @@ *

    Statements are not required to be thread-safe, but they can be used from multiple threads so * long as clients take care to serialize accesses to a statement. */ -public interface AdbcStatement extends AutoCloseable { - /** Set a generic query option. */ +public interface AdbcStatement extends AutoCloseable, AdbcOptions { + /** + * Cancel execution of a query. + * + *

    This can be used to interrupt execution of a method like {@link #executeQuery()}. + * + *

    This method must be thread-safe (other method are not necessarily thread-safe). + * + * @since ADBC API revision 1.1.0 + */ + default void cancel() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support cancel"); + } + + /** + * Set a generic query option. + * + * @deprecated Prefer {@link #setOption(TypedKey, Object)}. + */ default void setOption(String key, Object value) throws AdbcException { throw AdbcException.notImplemented("Unsupported option " + key); } @@ -94,6 +112,46 @@ default PartitionResult executePartitioned() throws AdbcException { throw AdbcException.notImplemented("Statement does not support executePartitioned"); } + /** + * Get the schema of the result set without executing the query. + * + * @since ADBC API revision 1.1.0 + */ + default Schema executeSchema() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support executeSchema"); + } + + /** + * Execute a result set-generating query and get a list of partitions of the result set. + * + *

    These can be serialized and deserialized for parallel and/or distributed fetching. + * + *

    This may invalidate any prior result sets. + * + * @since ADBC API revision 1.1.0 + */ + default Iterator pollPartitioned() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support pollPartitioned"); + } + + /** + * Get the progress of executing a query. + * + * @since ADBC API revision 1.1.0 + */ + default double getProgress() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support getProgress"); + } + + /** + * Get the upper bound of the progress. + * + * @since ADBC API revision 1.1.0 + */ + default double getMaxProgress() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support getMaxProgress"); + } + /** * Get the schema for bound parameters. * diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/BulkIngestMode.java b/java/core/src/main/java/org/apache/arrow/adbc/core/BulkIngestMode.java index 2ab16ac428..e23e8de4ac 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/BulkIngestMode.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/BulkIngestMode.java @@ -24,7 +24,20 @@ public enum BulkIngestMode { /** * Do not create the table and append data; error if the table does not exist ({@link * AdbcStatusCode#NOT_FOUND}) or does not match the schema of the data to append ({@link - * AdbcStatusCode#ALREADY_EXISTS}). * + * AdbcStatusCode#ALREADY_EXISTS}). */ APPEND, + /** + * Create the table and insert data; drop the original table if it already exists. + * + * @since ADBC API revision 1.1.0 + */ + REPLACE, + /** + * Insert data; create the table if it does not exist, or error ({@link + * AdbcStatusCode#ALREADY_EXISTS}) if the table exists, but the schema does not match the schema + * of the data to append. + */ + CREATE_APPEND, + ; } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java b/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java new file mode 100644 index 0000000000..13521fb82e --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.core; + +import java.util.Objects; + +/** Additional details (not necessarily human-readable) contained in an {@link AdbcException}. */ +public class ErrorDetail { + private final String key; + private final Object value; + + public ErrorDetail(String key, Object value) { + this.key = Objects.requireNonNull(key); + this.value = Objects.requireNonNull(value); + } + + public String getKey() { + return key; + } + + public Object getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ErrorDetail that = (ErrorDetail) o; + return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue()); + } + + @Override + public int hashCode() { + return Objects.hash(getKey(), getValue()); + } + + @Override + public String toString() { + return "ErrorDetail{" + "key='" + key + '\'' + ", value=" + value + '}'; + } +} diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java index a14c04c700..c059bb1b57 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java @@ -19,6 +19,8 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -30,10 +32,13 @@ private StandardSchemas() { throw new AssertionError("Do not instantiate this class"); } - private static final ArrowType INT16 = new ArrowType.Int(16, true); - private static final ArrowType INT32 = new ArrowType.Int(32, true); - private static final ArrowType INT64 = new ArrowType.Int(64, true); + private static final ArrowType INT16 = Types.MinorType.SMALLINT.getType(); + private static final ArrowType INT32 = Types.MinorType.INT.getType(); + private static final ArrowType INT64 = Types.MinorType.BIGINT.getType(); private static final ArrowType UINT32 = new ArrowType.Int(32, false); + private static final ArrowType UINT64 = new ArrowType.Int(64, false); + private static final ArrowType FLOAT64 = + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); /** The schema of the result set of {@link AdbcConnection#getInfo(int[])}}. */ public static final Schema GET_INFO_SCHEMA = @@ -83,11 +88,11 @@ private StandardSchemas() { Field.notNullable("constraint_type", ArrowType.Utf8.INSTANCE), new Field( "constraint_column_names", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList(Field.nullable("item", new ArrowType.Utf8()))), new Field( "constraint_column_usage", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), USAGE_SCHEMA)))); @@ -119,12 +124,12 @@ private StandardSchemas() { new Field("table_type", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "table_columns", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), COLUMN_SCHEMA))), new Field( "table_constraints", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field( "item", FieldType.nullable(ArrowType.Struct.INSTANCE), CONSTRAINT_SCHEMA)))); @@ -134,20 +139,76 @@ private StandardSchemas() { new Field("db_schema_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "db_schema_tables", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), TABLE_SCHEMA)))); + /** + * The schema of the result of {@link AdbcConnection#getObjects(AdbcConnection.GetObjectsDepth, + * String, String, String, String[], String)}. + */ public static final Schema GET_OBJECTS_SCHEMA = new Schema( Arrays.asList( new Field("catalog_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "catalog_db_schemas", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field( "item", FieldType.nullable(ArrowType.Struct.INSTANCE), DB_SCHEMA_SCHEMA))))); + + public static final List STATISTICS_VALUE_SCHEMA = + Arrays.asList( + Field.nullable("int64", INT64), + Field.nullable("uint64", UINT64), + Field.nullable("float64", FLOAT64), + Field.nullable("binary", ArrowType.Binary.INSTANCE)); + + public static final List STATISTICS_SCHEMA = + Arrays.asList( + Field.notNullable("table_name", ArrowType.Utf8.INSTANCE), + Field.nullable("column_name", ArrowType.Utf8.INSTANCE), + Field.notNullable("statistic_key", INT16), + new Field( + "statistic_value", + FieldType.notNullable(new ArrowType.Union(UnionMode.Dense, new int[] {0, 1, 2, 3})), + STATISTICS_VALUE_SCHEMA), + Field.notNullable("statistic_is_approximate", ArrowType.Bool.INSTANCE)); + + public static final List STATISTICS_DB_SCHEMA_SCHEMA = + Arrays.asList( + new Field("db_schema_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), + new Field( + "db_schema_statistics", + FieldType.notNullable(ArrowType.List.INSTANCE), + Collections.singletonList( + new Field( + "item", FieldType.nullable(ArrowType.Struct.INSTANCE), STATISTICS_SCHEMA)))); + + /** + * The schema of the result of {@link AdbcConnection#getStatistics(String, String, String, + * boolean)}. + */ + public static final Schema GET_STATISTICS_SCHEMA = + new Schema( + Arrays.asList( + new Field("catalog_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), + new Field( + "catalog_db_schemas", + FieldType.notNullable(ArrowType.List.INSTANCE), + Collections.singletonList( + new Field( + "item", + FieldType.nullable(ArrowType.Struct.INSTANCE), + STATISTICS_DB_SCHEMA_SCHEMA))))); + + /** The schema of the result of {@link AdbcConnection#getStatisticNames()}. */ + public static final Schema GET_STATISTIC_NAMES_SCHEMA = + new Schema( + Arrays.asList( + Field.notNullable("statistic_name", ArrowType.Utf8.INSTANCE), + Field.notNullable("statistic_name", INT16))); } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java new file mode 100644 index 0000000000..f5097f4413 --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.core; + +import java.util.Objects; + +/** + * Definitions of standard statistic names/keys. + * + *

    Statistic names are returned from {@link AdbcConnection#getStatistics(String, String, String, + * boolean)} in a dictionary-encoded form. This class provides the names and dictionary-encoded form + * of statistics defined by ADBC. + */ +public enum StandardStatistics { + /** + * The average byte width statistic. The average size in bytes of a row in the column. Value type + * is float64. + * + *

    For example, this is roughly the average length of a string for a string column. + */ + AVERAGE_BYTE_WIDTH("adbc.statistic.byte_width", (short) 0), + /** + * The distinct value count (NDV) statistic. The number of distinct values in the column. Value + * type is int64 (when not approximate) or float64 (when approximate). + */ + DISTINCT_COUNT("adbc.statistic.distinct_count", (short) 1), + /** + * The max byte width statistic. The maximum size in bytes of a row in the column. Value type is + * int64 (when not approximate) or float64 (when approximate). + * + *

    For example, this is the maximum length of a string for a string column. + */ + MAX_BYTE_WIDTH("adbc.statistic.byte_width", (short) 2), + /** The max value statistic. Value type is column-dependent. */ + MAX_VALUE("adbc.statistic.byte_width", (short) 3), + /** The min value statistic. Value type is column-dependent. */ + MIN_VALUE("adbc.statistic.byte_width", (short) 4), + /** + * The null count statistic. The number of values that are null in the column. Value type is int64 + * (when not approximate) or float64 (when approximate). + */ + NULL_COUNT("adbc.statistic.null_count", (short) 5), + /** + * The row count statistic. The number of rows in the column or table. Value type is int64 (when + * not approximate) or float64 (when approximate). + */ + ROW_COUNT("adbc.statistic.row_count", (short) 6), + ; + + private final String name; + private final short key; + + StandardStatistics(String name, short key) { + this.name = Objects.requireNonNull(name); + this.key = key; + } + + /** Get the statistic name. */ + public String getName() { + return name; + } + + /** Get the dictionary-encoded name. */ + public short getKey() { + return key; + } +} diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java b/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java new file mode 100644 index 0000000000..21523bb429 --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.core; + +import java.util.Map; +import java.util.Objects; + +/** + * A typesafe option key. + * + * @since ADBC API revision 1.1.0 + * @param The option value type. + */ +public final class TypedKey { + private final String key; + private final Class type; + + public TypedKey(String key, Class type) { + this.key = Objects.requireNonNull(key); + this.type = Objects.requireNonNull(type); + } + + /** Get the option key. */ + public String getKey() { + return key; + } + + /** + * Get the option value (if it was set) and check the type. + * + * @throws ClassCastException if the value is of the wrong type. + */ + public T get(Map options) { + Object value = options.get(key); + if (value == null) { + return null; + } + return type.cast(value); + } + + /** + * Set this option in an options map (like for {@link AdbcDriver#open(Map)}. + * + * @param options The options. + * @param value The option value. + */ + public void set(Map options, T value) { + options.put(key, value); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TypedKey that = (TypedKey) o; + return Objects.equals(key, that.key) && Objects.equals(type, that.type); + } + + @Override + public int hashCode() { + return Objects.hash(key, type); + } + + @Override + public String toString() { + return "AdbcOptionKey{" + key + ", " + type + '}'; + } +} diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java index 43a6df99c9..d3f79889ec 100644 --- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java +++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java @@ -47,7 +47,7 @@ public AdbcDatabase initDatabase(BufferAllocator allocator) throws AdbcException String url = getFlightLocation(); final Map parameters = new HashMap<>(); - parameters.put(AdbcDriver.PARAM_URL, url); + AdbcDriver.PARAM_URI.set(parameters, url); return new FlightSqlDriver(allocator).open(parameters); } diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java index 306f69e44f..8a40714970 100644 --- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java +++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java @@ -30,4 +30,16 @@ public static void beforeAll() { @Override @Disabled("Requires spec clarification") public void prepareQueryWithParameters() {} + + @Override + @Disabled("Not supported") + public void executeSchema() {} + + @Override + @Disabled("Not supported") + public void executeSchemaPrepared() {} + + @Override + @Disabled("Not supported") + public void executeSchemaParams() {} } diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml index db7aef9900..0287c52d97 100644 --- a/java/driver/flight-sql/pom.xml +++ b/java/driver/flight-sql/pom.xml @@ -66,5 +66,17 @@ org.apache.arrow.adbc adbc-sql + + + + org.assertj + assertj-core + test + + + org.junit.jupiter + junit-jupiter + test + diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java index 30fc460b8e..5015ecfcfe 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java @@ -43,17 +43,22 @@ public class FlightSqlDriver implements AdbcDriver { @Override public AdbcDatabase open(Map parameters) throws AdbcException { - Object target = parameters.get("adbc.url"); - if (!(target instanceof String)) { - throw AdbcException.invalidArgument( - "[Flight SQL] Must provide String " + PARAM_URL + " parameter"); + String uri = PARAM_URI.get(parameters); + if (uri == null) { + Object target = parameters.get("adbc.url"); + if (!(target instanceof String)) { + throw AdbcException.invalidArgument( + "[Flight SQL] Must provide String " + PARAM_URI + " parameter"); + } + uri = (String) target; } + Location location; try { - location = new Location((String) target); + location = new Location(uri); } catch (URISyntaxException e) { throw AdbcException.invalidArgument( - String.format("[Flight SQL] Location %s is invalid: %s", target, e)) + String.format("[Flight SQL] Location %s is invalid: %s", uri, e)) .withCause(e); } Object quirks = parameters.get(PARAM_SQL_QUIRKS); diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java index cb6b3038f8..45b42df2ee 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java @@ -17,8 +17,11 @@ package org.apache.arrow.adbc.driver.flightsql; import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.core.ErrorDetail; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; @@ -72,7 +75,24 @@ static AdbcStatusCode fromFlightStatusCode(FlightStatusCode code) { } static AdbcException fromFlightException(FlightRuntimeException e) { + List errorDetails = new ArrayList<>(); + for (String key : e.status().metadata().keys()) { + if (key.endsWith("-bin")) { + for (byte[] value : e.status().metadata().getAllByte(key)) { + errorDetails.add(new ErrorDetail(key, value)); + } + } else { + for (String value : e.status().metadata().getAll(key)) { + errorDetails.add(new ErrorDetail(key, value)); + } + } + } return new AdbcException( - e.getMessage(), e.getCause(), fromFlightStatusCode(e.status().code()), null, 0); + e.getMessage(), + e.getCause(), + fromFlightStatusCode(e.status().code()), + null, + 0, + errorDetails); } } diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java index 1fd8b910c7..e64508b4bf 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java @@ -247,6 +247,18 @@ public QueryResult executeQuery() throws AdbcException { new FlightInfoReader(allocator, client, clientCache, info.getEndpoints())); } + @Override + public Schema executeSchema() throws AdbcException { + if (bulkOperation != null) { + throw AdbcException.invalidState("[Flight SQL] Must executeUpdate() for bulk ingestion"); + } else if (sqlQuery == null) { + throw AdbcException.invalidState("[Flight SQL] Must setSqlQuery() before execute"); + } + return execute( + FlightSqlClient.PreparedStatement::getResultSetSchema, + (client) -> client.getExecuteSchema(sqlQuery).getSchema()); + } + @Override public UpdateResult executeUpdate() throws AdbcException { if (bulkOperation != null) { diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java new file mode 100644 index 0000000000..c617f664fa --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.grpc.Metadata; +import io.grpc.Status; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.adbc.core.ErrorDetail; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.ErrorFlightMetadata; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** Test that gRPC error details make it through. */ +class DetailsTest { + static BufferAllocator allocator; + static Producer producer; + static FlightServer server; + static AdbcDriver driver; + static AdbcDatabase database; + AdbcConnection connection; + AdbcStatement statement; + + @BeforeAll + static void beforeAll() throws Exception { + allocator = new RootAllocator(); + producer = new Producer(); + server = + FlightServer.builder() + .allocator(allocator) + .producer(producer) + .location(Location.forGrpcInsecure("localhost", 0)) + .build(); + server.start(); + driver = new FlightSqlDriver(allocator); + Map parameters = new HashMap<>(); + AdbcDriver.PARAM_URI.set( + parameters, Location.forGrpcInsecure("localhost", server.getPort()).getUri().toString()); + database = driver.open(parameters); + } + + @BeforeEach + void beforeEach() throws Exception { + connection = database.connect(); + statement = connection.createStatement(); + } + + @AfterEach + void afterEach() throws Exception { + AutoCloseables.close(statement, connection); + } + + @AfterAll + static void afterAll() throws Exception { + AutoCloseables.close(database, server, allocator); + } + + @Test + void flightDetails() throws Exception { + statement.setSqlQuery("flight"); + + AdbcException exception = + assertThrows( + AdbcException.class, + () -> { + try (AdbcStatement.QueryResult result = statement.executeQuery()) {} + }); + + assertThat(exception.getDetails()).contains(new ErrorDetail("x-foo", "text")); + Optional binaryKey = + exception.getDetails().stream().filter(x -> x.getKey().equals("x-foo-bin")).findAny(); + assertThat(binaryKey) + .get() + .extracting(ErrorDetail::getValue) + .isEqualTo("text".getBytes(StandardCharsets.UTF_8)); + } + + @Test + void grpcDetails() throws Exception { + statement.setSqlQuery("grpc"); + + AdbcException exception = + assertThrows( + AdbcException.class, + () -> { + try (AdbcStatement.QueryResult result = statement.executeQuery()) {} + }); + + assertThat(exception.getDetails()).contains(new ErrorDetail("x-foo", "text")); + Optional binaryKey = + exception.getDetails().stream().filter(x -> x.getKey().equals("x-foo-bin")).findAny(); + assertThat(binaryKey) + .get() + .extracting(ErrorDetail::getValue) + .isEqualTo("text".getBytes(StandardCharsets.UTF_8)); + } + + static class Producer implements FlightSqlProducer { + Metadata.Key BINARY_KEY = Metadata.Key.of("x-foo-bin", Metadata.BINARY_BYTE_MARSHALLER); + Metadata.Key TEXT_KEY = Metadata.Key.of("x-foo", Metadata.ASCII_STRING_MARSHALLER); + + @Override + public FlightInfo getFlightInfoStatement( + FlightSql.CommandStatementQuery commandStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + if (commandStatementQuery.getQuery().equals("flight")) { + // Using Flight path + ErrorFlightMetadata metadata = new ErrorFlightMetadata(); + metadata.insert("x-foo", "text"); + metadata.insert("x-foo-bin", "text".getBytes(StandardCharsets.UTF_8)); + throw CallStatus.UNKNOWN + .withDescription("Expected") + .withMetadata(metadata) + .toRuntimeException(); + } else if (commandStatementQuery.getQuery().equals("grpc")) { + // Using gRPC path + Metadata trailers = new Metadata(); + trailers.put(TEXT_KEY, "text"); + trailers.put(BINARY_KEY, "text".getBytes(StandardCharsets.UTF_8)); + throw Status.UNKNOWN.asRuntimeException(trailers); + } + + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + + // No-op implementations + + @Override + public void createPreparedStatement( + FlightSql.ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest, + CallContext callContext, + StreamListener streamListener) {} + + @Override + public void closePreparedStatement( + FlightSql.ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, + CallContext callContext, + StreamListener streamListener) {} + + @Override + public FlightInfo getFlightInfoPreparedStatement( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public SchemaResult getSchemaStatement( + FlightSql.CommandStatementQuery commandStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamStatement( + FlightSql.TicketStatementQuery ticketStatementQuery, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamPreparedStatement( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public Runnable acceptPutStatement( + FlightSql.CommandStatementUpdate commandStatementUpdate, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public Runnable acceptPutPreparedStatementUpdate( + FlightSql.CommandPreparedStatementUpdate commandPreparedStatementUpdate, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public Runnable acceptPutPreparedStatementQuery( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public FlightInfo getFlightInfoSqlInfo( + FlightSql.CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamSqlInfo( + FlightSql.CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTypeInfo( + FlightSql.CommandGetXdbcTypeInfo commandGetXdbcTypeInfo, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTypeInfo( + FlightSql.CommandGetXdbcTypeInfo commandGetXdbcTypeInfo, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoCatalogs( + FlightSql.CommandGetCatalogs commandGetCatalogs, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamCatalogs( + CallContext callContext, ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoSchemas( + FlightSql.CommandGetDbSchemas commandGetDbSchemas, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamSchemas( + FlightSql.CommandGetDbSchemas commandGetDbSchemas, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTables( + FlightSql.CommandGetTables commandGetTables, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTables( + FlightSql.CommandGetTables commandGetTables, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTableTypes( + FlightSql.CommandGetTableTypes commandGetTableTypes, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTableTypes( + CallContext callContext, ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoPrimaryKeys( + FlightSql.CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamPrimaryKeys( + FlightSql.CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoExportedKeys( + FlightSql.CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public FlightInfo getFlightInfoImportedKeys( + FlightSql.CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public FlightInfo getFlightInfoCrossReference( + FlightSql.CommandGetCrossReference commandGetCrossReference, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamExportedKeys( + FlightSql.CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamImportedKeys( + FlightSql.CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamCrossReference( + FlightSql.CommandGetCrossReference commandGetCrossReference, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void close() throws Exception {} + + @Override + public void listFlights( + CallContext callContext, Criteria criteria, StreamListener streamListener) {} + } +} diff --git a/java/driver/jdbc-validation-derby/src/test/java/org/apache/arrow/adbc/driver/jdbc/derby/DerbyStatementTest.java b/java/driver/jdbc-validation-derby/src/test/java/org/apache/arrow/adbc/driver/jdbc/derby/DerbyStatementTest.java index 0f7138707d..9d80935cda 100644 --- a/java/driver/jdbc-validation-derby/src/test/java/org/apache/arrow/adbc/driver/jdbc/derby/DerbyStatementTest.java +++ b/java/driver/jdbc-validation-derby/src/test/java/org/apache/arrow/adbc/driver/jdbc/derby/DerbyStatementTest.java @@ -20,6 +20,7 @@ import java.nio.file.Path; import org.apache.arrow.adbc.driver.testsuite.AbstractStatementTest; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.io.TempDir; class DerbyStatementTest extends AbstractStatementTest { @@ -29,4 +30,8 @@ class DerbyStatementTest extends AbstractStatementTest { static void beforeAll() { quirks = new DerbyQuirks(tempDir); } + + @Override + @Disabled("Not supported") + public void executeSchemaParams() {} } diff --git a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java index ccce7db70d..fce9ff134d 100644 --- a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java +++ b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java @@ -37,6 +37,9 @@ public class PostgresqlQuirks extends SqlValidationQuirks { static final String POSTGRESQL_URL_ENV_VAR = "ADBC_JDBC_POSTGRESQL_URL"; static final String POSTGRESQL_USER_ENV_VAR = "ADBC_JDBC_POSTGRESQL_USER"; static final String POSTGRESQL_PASSWORD_ENV_VAR = "ADBC_JDBC_POSTGRESQL_PASSWORD"; + static final String POSTGRESQL_DATABASE_ENV_VAR = "ADBC_JDBC_POSTGRESQL_DATABASE"; + + String catalog = "postgres"; static String makeJdbcUrl() { final String postgresUrl = System.getenv(POSTGRESQL_URL_ENV_VAR); @@ -49,12 +52,21 @@ static String makeJdbcUrl() { return String.format("jdbc:postgresql://%s?user=%s&password=%s", postgresUrl, user, password); } + public Connection getJdbcConnection() throws SQLException { + return DriverManager.getConnection(makeJdbcUrl()); + } + @Override public AdbcDatabase initDatabase(BufferAllocator allocator) throws AdbcException { String url = makeJdbcUrl(); + final String catalog = System.getenv(POSTGRESQL_DATABASE_ENV_VAR); + Assumptions.assumeFalse( + catalog == null, "PostgreSQL catalog not found, set " + POSTGRESQL_DATABASE_ENV_VAR); + this.catalog = catalog; + final Map parameters = new HashMap<>(); - parameters.put(AdbcDriver.PARAM_URL, url); + AdbcDriver.PARAM_URI.set(parameters, url); parameters.put(JdbcDriver.PARAM_JDBC_QUIRKS, StandardJdbcQuirks.POSTGRESQL); return new JdbcDriver(allocator).open(parameters); } @@ -71,8 +83,12 @@ public void cleanupTable(String name) throws Exception { @Override public String defaultCatalog() { - // XXX: this should really come from configuration - return "postgres"; + return catalog; + } + + @Override + public String defaultDbSchema() { + return "public"; } @Override @@ -94,4 +110,9 @@ public TimeUnit defaultTimeUnit() { public TimeUnit defaultTimestampUnit() { return TimeUnit.MICROSECOND; } + + @Override + public boolean supportsCurrentCatalog() { + return true; + } } diff --git a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java new file mode 100644 index 0000000000..13ca0ee191 --- /dev/null +++ b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.jdbc.postgresql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.entry; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.Statement; +import java.util.Map; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.StandardStatistics; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class StatisticsTest { + static PostgresqlQuirks quirks; + + @BeforeAll + static void beforeAll() { + quirks = new PostgresqlQuirks(); + } + + @Test + void adbc() throws Exception { + try (Connection connection = quirks.getJdbcConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate("DROP TABLE IF EXISTS adbcpkeytest"); + statement.executeUpdate("CREATE TABLE adbcpkeytest (key SERIAL PRIMARY KEY, value INT)"); + statement.executeUpdate("INSERT INTO adbcpkeytest (value) VALUES (0), (1), (2)"); + statement.executeUpdate("ANALYZE adbcpkeytest"); + } + + try (BufferAllocator allocator = new RootAllocator(); + AdbcDatabase database = quirks.initDatabase(allocator); + AdbcConnection connection = database.connect(); + ArrowReader reader = connection.getStatistics(null, null, "adbcpkeytest", true)) { + assertThat(reader.loadNextBatch()).isTrue(); + VectorSchemaRoot vsr = reader.getVectorSchemaRoot(); + assertThat(vsr.getRowCount()).isEqualTo(1); + + ListVector catalogDbSchemas = (ListVector) vsr.getVector(1); + assertThat(catalogDbSchemas.getValueCount()).isEqualTo(1); + + StructVector catalogDbSchema = (StructVector) catalogDbSchemas.getDataVector(); + ListVector dbSchemaStatistics = (ListVector) catalogDbSchema.getVectorById(1); + assertThat(dbSchemaStatistics.getValueCount()).isEqualTo(1); + + @SuppressWarnings("unchecked") + Map statistic = (Map) dbSchemaStatistics.getObject(0).get(0); + assertThat(statistic) + .contains( + entry("table_name", new Text("adbcpkeytest")), + entry("statistic_key", StandardStatistics.DISTINCT_COUNT.getKey()), + entry("statistic_value", 3L)); + + assertThat(reader.loadNextBatch()).isFalse(); + } + } + + /** Validate what PostgreSQL does. */ + @Test + void jdbc() throws Exception { + try (Connection connection = quirks.getJdbcConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate("DROP TABLE IF EXISTS adbcpkeytest"); + statement.executeUpdate("CREATE TABLE adbcpkeytest (key SERIAL PRIMARY KEY, value INT)"); + statement.executeUpdate("INSERT INTO adbcpkeytest (value) VALUES (0), (1), (2)"); + statement.executeUpdate("ANALYZE adbcpkeytest"); + + int count = 0; + try (ResultSet rs = + connection.getMetaData().getIndexInfo(null, null, "adbcpkeytest", false, true)) { + ResultSetMetaData rsmd = rs.getMetaData(); + while (rs.next()) { + // For debugging + for (int i = 1; i <= rsmd.getColumnCount(); i++) { + System.out.println(rsmd.getColumnName(i) + " => " + rs.getObject(i)); + } + System.out.println("==="); + + // TABLE_NAME + assertThat(rs.getString(3)).isEqualTo("adbcpkeytest"); + // TYPE + assertThat(rs.getShort(7)).isEqualTo(DatabaseMetaData.tableIndexOther); + // CARDINALITY + assertThat(rs.getLong(11)).isEqualTo(3); + + count++; + } + } + + assertThat(count).isEqualTo(1); + } + } +} diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java index 02c2ccac22..ae5ec226fc 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java @@ -25,10 +25,12 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.arrow.adbc.core.AdbcDriver; import org.apache.arrow.adbc.core.AdbcInfoCode; import org.apache.arrow.adbc.core.StandardSchemas; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -37,6 +39,7 @@ /** Helper class to track state needed to build up the info structure. */ final class InfoMetadataBuilder implements AutoCloseable { private static final byte STRING_VALUE_TYPE_ID = (byte) 0; + private static final byte BIGINT_VALUE_TYPE_ID = (byte) 2; private static final Map SUPPORTED_CODES = new HashMap<>(); private final Collection requestedCodes; private final DatabaseMetaData dbmd; @@ -45,6 +48,7 @@ final class InfoMetadataBuilder implements AutoCloseable { final UInt4Vector infoCodes; final DenseUnionVector infoValues; final VarCharVector stringValues; + final BigIntVector bigIntValues; @FunctionalInterface interface AddInfo { @@ -74,6 +78,11 @@ interface AddInfo { final String driverVersion = b.dbmd.getDriverVersion() + " (ADBC Driver Version 0.0.1)"; b.setStringValue(idx, driverVersion); }); + SUPPORTED_CODES.put( + AdbcInfoCode.DRIVER_ADBC_VERSION.getValue(), + (b, idx) -> { + b.setBigIntValue(idx, AdbcDriver.ADBC_VERSION_1_1_0); + }); } InfoMetadataBuilder(BufferAllocator allocator, Connection connection, int[] infoCodes) @@ -86,7 +95,18 @@ interface AddInfo { this.dbmd = connection.getMetaData(); this.infoCodes = (UInt4Vector) root.getVector(0); this.infoValues = (DenseUnionVector) root.getVector(1); - this.stringValues = this.infoValues.getVarCharVector((byte) 0); + this.stringValues = this.infoValues.getVarCharVector(STRING_VALUE_TYPE_ID); + this.bigIntValues = this.infoValues.getBigIntVector(BIGINT_VALUE_TYPE_ID); + } + + void setBigIntValue(int index, long value) { + infoValues.setValueCount(index + 1); + infoValues.setTypeId(index, BIGINT_VALUE_TYPE_ID); + bigIntValues.setSafe(index, value); + infoValues + .getOffsetBuffer() + .setInt((long) index * DenseUnionVector.OFFSET_WIDTH, bigIntValues.getValueCount()); + bigIntValues.setValueCount(bigIntValues.getValueCount() + 1); } void setStringValue(int index, final String value) { diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java index 1ddbf1c88a..aba972a9a2 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java @@ -42,12 +42,7 @@ public class JdbcArrowReader extends ArrowReader { JdbcArrowReader(BufferAllocator allocator, ResultSet resultSet, Schema overrideSchema) throws AdbcException { super(allocator); - final JdbcToArrowConfig config = - new JdbcToArrowConfigBuilder() - .setAllocator(allocator) - .setCalendar(JdbcToArrowUtils.getUtcCalendar()) - .setTargetBatchSize(1024) - .build(); + final JdbcToArrowConfig config = makeJdbcConfig(allocator); try { this.delegate = JdbcToArrow.sqlToArrowVectorIterator(resultSet, config); } catch (SQLException e) { @@ -75,6 +70,14 @@ public class JdbcArrowReader extends ArrowReader { } } + static JdbcToArrowConfig makeJdbcConfig(BufferAllocator allocator) { + return new JdbcToArrowConfigBuilder() + .setAllocator(allocator) + .setCalendar(JdbcToArrowUtils.getUtcCalendar()) + .setTargetBatchSize(1024) + .build(); + } + @Override public boolean loadNextBatch() { if (!delegate.hasNext()) return false; diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java index 398ef6d42e..8f66c154fa 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java @@ -21,7 +21,9 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; @@ -29,15 +31,24 @@ import org.apache.arrow.adbc.core.BulkIngestMode; import org.apache.arrow.adbc.core.IsolationLevel; import org.apache.arrow.adbc.core.StandardSchemas; +import org.apache.arrow.adbc.core.StandardStatistics; import org.apache.arrow.adbc.driver.jdbc.adapter.JdbcFieldInfoExtra; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.holders.NullableBigIntHolder; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; public class JdbcConnection implements AdbcConnection { private final BufferAllocator allocator; @@ -117,6 +128,165 @@ public ArrowReader getObjects( } } + static final class Statistic { + String table; + String column; + short key; + long value; + boolean multiColumn = false; + } + + @Override + public ArrowReader getStatistics( + String catalogPattern, String dbSchemaPattern, String tableNamePattern, boolean approximate) + throws AdbcException { + if (tableNamePattern == null) { + throw AdbcException.notImplemented( + JdbcDriverUtil.prefixExceptionMessage("getStatistics: must supply table name")); + } + + try (final VectorSchemaRoot root = + VectorSchemaRoot.create(StandardSchemas.GET_STATISTICS_SCHEMA, allocator); + ResultSet rs = + connection + .getMetaData() + .getIndexInfo( + catalogPattern, + dbSchemaPattern, + tableNamePattern, /*unique*/ + false, + approximate)) { + // Build up the statistics in-memory and then return a constant reader. + // We have to read and sort the data first because the ordering is not by the catalog/etc. + + // {catalog: {schema: {index_name: statistic}}} + Map>> allStatistics = new HashMap<>(); + + while (rs.next()) { + String catalog = rs.getString(1); + String schema = rs.getString(2); + String table = rs.getString(3); + String index = rs.getString(6); + short statisticType = rs.getShort(7); + String column = rs.getString(9); + long cardinality = rs.getLong(11); + + if (!allStatistics.containsKey(catalog)) { + allStatistics.put(catalog, new HashMap<>()); + } + + Map> catalogStats = allStatistics.get(catalog); + if (!catalogStats.containsKey(schema)) { + catalogStats.put(schema, new HashMap<>()); + } + + Map schemaStats = catalogStats.get(schema); + Statistic statistic = schemaStats.getOrDefault(index, new Statistic()); + if (schemaStats.containsKey(index)) { + // Multi-column index, ignore it + statistic.multiColumn = true; + continue; + } + + statistic.column = column; + statistic.table = table; + statistic.key = + statisticType == DatabaseMetaData.tableIndexStatistic + ? StandardStatistics.ROW_COUNT.getKey() + : StandardStatistics.DISTINCT_COUNT.getKey(); + statistic.value = cardinality; + schemaStats.put(index, statistic); + } + + VarCharVector catalogNames = (VarCharVector) root.getVector(0); + ListVector catalogDbSchemas = (ListVector) root.getVector(1); + StructVector dbSchemas = (StructVector) catalogDbSchemas.getDataVector(); + VarCharVector dbSchemaNames = (VarCharVector) dbSchemas.getVectorById(0); + ListVector dbSchemaStatistics = (ListVector) dbSchemas.getVectorById(1); + StructVector statistics = (StructVector) dbSchemaStatistics.getDataVector(); + VarCharVector tableNames = (VarCharVector) statistics.getVectorById(0); + VarCharVector columnNames = (VarCharVector) statistics.getVectorById(1); + SmallIntVector statisticKeys = (SmallIntVector) statistics.getVectorById(2); + DenseUnionVector statisticValues = (DenseUnionVector) statistics.getVectorById(3); + BitVector statisticIsApproximate = (BitVector) statistics.getVectorById(4); + + // Build up the Arrow result + Text text = new Text(); + NullableBigIntHolder holder = new NullableBigIntHolder(); + int catalogIndex = 0; + int schemaIndex = 0; + int statisticIndex = 0; + for (String catalog : allStatistics.keySet()) { + Map> schemas = allStatistics.get(catalog); + + if (catalog == null) { + catalogNames.setNull(catalogIndex); + } else { + text.set(catalog); + catalogNames.setSafe(catalogIndex, text); + } + catalogDbSchemas.startNewValue(catalogIndex); + + int schemaCount = 0; + for (String schema : schemas.keySet()) { + if (schema == null) { + dbSchemaNames.setNull(schemaIndex); + } else { + text.set(schema); + dbSchemaNames.setSafe(schemaIndex, text); + } + + dbSchemaStatistics.startNewValue(schemaIndex); + + Map indices = schemas.get(schema); + int statisticCount = 0; + for (Statistic statistic : indices.values()) { + if (statistic.multiColumn) { + continue; + } + + text.set(statistic.table); + tableNames.setSafe(statisticIndex, text); + if (statistic.column == null) { + columnNames.setNull(statisticIndex); + } else { + text.set(statistic.column); + columnNames.setSafe(statisticIndex, text); + } + statisticKeys.setSafe(statisticIndex, statistic.key); + statisticValues.setTypeId(statisticIndex, (byte) 0); + holder.isSet = 1; + holder.value = statistic.value; + statisticValues.setSafe(statisticIndex, holder); + statisticIsApproximate.setSafe(statisticIndex, approximate ? 1 : 0); + + statistics.setIndexDefined(statisticIndex++); + statisticCount++; + } + + dbSchemaStatistics.endValue(schemaIndex, statisticCount); + + dbSchemas.setIndexDefined(schemaIndex++); + schemaCount++; + } + + catalogDbSchemas.endValue(catalogIndex, schemaCount); + catalogIndex++; + } + root.setRowCount(catalogIndex); + + return RootArrowReader.fromRoot(allocator, root); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public ArrowReader getStatisticNames() throws AdbcException { + // TODO: + return AdbcConnection.super.getStatisticNames(); + } + @Override public Schema getTableSchema(String catalog, String dbSchema, String tableName) throws AdbcException { @@ -211,6 +381,42 @@ public void setAutoCommit(boolean enableAutoCommit) throws AdbcException { } } + @Override + public String getCurrentCatalog() throws AdbcException { + try { + return connection.getCatalog(); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public void setCurrentCatalog(String catalog) throws AdbcException { + try { + connection.setCatalog(catalog); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public String getCurrentDbSchema() throws AdbcException { + try { + return connection.getSchema(); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public void setCurrentDbSchema(String dbSchema) throws AdbcException { + try { + connection.setSchema(dbSchema); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + @Override public boolean getReadOnly() throws AdbcException { try { diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java index 95b3775f68..fd39e6d08b 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java @@ -30,6 +30,7 @@ import java.util.stream.LongStream; import org.apache.arrow.adapter.jdbc.JdbcFieldInfo; import org.apache.arrow.adapter.jdbc.JdbcParameterBinder; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig; import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; @@ -263,6 +264,41 @@ public QueryResult executeQuery() throws AdbcException { return new QueryResult(/*affectedRows=*/ -1, reader); } + @Override + public Schema executeSchema() throws AdbcException { + if (bulkOperation != null) { + throw AdbcException.invalidState("[JDBC] Call executeUpdate() for bulk operations"); + } else if (sqlQuery == null) { + throw AdbcException.invalidState("[JDBC] Must setSqlQuery() first"); + } + try { + invalidatePriorQuery(); + final PreparedStatement preparedStatement; + final PreparedStatement ownedStatement; + if (statement instanceof PreparedStatement) { + preparedStatement = (PreparedStatement) statement; + if (bindRoot != null) { + JdbcParameterBinder.builder(preparedStatement, bindRoot).bindAll().build().next(); + } + ownedStatement = null; + } else { + // new statement + preparedStatement = connection.prepareStatement(sqlQuery); + ownedStatement = preparedStatement; + } + + final JdbcToArrowConfig config = JdbcArrowReader.makeJdbcConfig(allocator); + final Schema schema = + JdbcToArrowUtils.jdbcToArrowSchema(preparedStatement.getMetaData(), config); + if (ownedStatement != null) { + ownedStatement.close(); + } + return schema; + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + @Override public Schema getParameterSchema() throws AdbcException { if (statement instanceof PreparedStatement) { diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java index 9915636d69..54e6059046 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java @@ -18,6 +18,7 @@ package org.apache.arrow.adbc.driver.testsuite; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcDatabase; @@ -48,6 +49,19 @@ public void afterEach() throws Exception { AutoCloseables.close(connection, database, allocator); } + @Test + void currentCatalog() throws Exception { + assumeThat(quirks.supportsCurrentCatalog()).isTrue(); + + assertThat(connection.getCurrentCatalog()).isEqualTo(quirks.defaultCatalog()); + connection.setCurrentCatalog(quirks.defaultCatalog()); + assertThat(connection.getCurrentCatalog()).isEqualTo(quirks.defaultCatalog()); + + assertThat(connection.getCurrentDbSchema()).isEqualTo(quirks.defaultDbSchema()); + connection.setCurrentDbSchema(quirks.defaultDbSchema()); + assertThat(connection.getCurrentDbSchema()).isEqualTo(quirks.defaultDbSchema()); + } + @Test void multipleConnections() throws Exception { try (final AdbcConnection ignored = database.connect()) {} diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java index e7a1a5743a..4d9184a4bb 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java @@ -19,6 +19,7 @@ import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertField; import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertRoot; +import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertSchema; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -239,6 +240,62 @@ public void bulkIngestCreateConflict() throws Exception { } } + @Test + public void executeSchema() throws Exception { + util.ingestTableIntsStrs(allocator, connection, tableName); + final String name = quirks.caseFoldColumnName("STRS"); + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT " + name + " FROM " + tableName); + final Schema actualSchema = stmt.executeSchema(); + assertSchema(actualSchema) + .isEqualTo( + new Schema( + Collections.singletonList( + Field.nullable(name, Types.MinorType.VARCHAR.getType())))); + } + } + + @Test + public void executeSchemaPrepared() throws Exception { + util.ingestTableIntsStrs(allocator, connection, tableName); + final String name = quirks.caseFoldColumnName("STRS"); + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT " + name + " FROM " + tableName); + stmt.prepare(); + final Schema actualSchema = stmt.executeSchema(); + assertSchema(actualSchema) + .isEqualTo( + new Schema( + Collections.singletonList( + Field.nullable(name, Types.MinorType.VARCHAR.getType())))); + } + } + + @Test + public void executeSchemaParams() throws Exception { + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT ? AS FOO"); + stmt.prepare(); + Schema actualSchema = stmt.executeSchema(); + // Actual type unknown + assertThat(actualSchema.getFields().size()).isEqualTo(1); + + final Schema schema = + new Schema( + Collections.singletonList( + Field.nullable( + quirks.caseFoldColumnName("foo"), Types.MinorType.VARCHAR.getType()))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + ((VarCharVector) root.getVector(0)).setSafe(0, "foo".getBytes(StandardCharsets.UTF_8)); + root.setRowCount(1); + stmt.bind(root); + + actualSchema = stmt.executeSchema(); + assertSchema(actualSchema).isEqualTo(schema); + } + } + } + @Test public void prepareQuery() throws Exception { final Schema expectedSchema = util.ingestTableIntsStrs(allocator, connection, tableName); diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java index 120ecab255..a5da97f658 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java @@ -35,6 +35,11 @@ public void cleanupTable(String name) throws Exception {} /** Get the name of the default catalog. */ public abstract String defaultCatalog(); + /** Get the name of the default schema. */ + public String defaultDbSchema() { + return ""; + } + /** Normalize a table name. */ public String caseFoldTableName(String name) { return name; @@ -110,4 +115,8 @@ public ArrowType defaultTimeType() { public TimeUnit defaultTimestampUnit() { return TimeUnit.MILLISECOND; } + + public boolean supportsCurrentCatalog() { + return false; + } } diff --git a/python/adbc_driver_manager/MANIFEST.in b/python/adbc_driver_manager/MANIFEST.in index fe98528271..306c31144f 100644 --- a/python/adbc_driver_manager/MANIFEST.in +++ b/python/adbc_driver_manager/MANIFEST.in @@ -22,4 +22,7 @@ include NOTICE.txt include adbc_driver_manager/adbc.h include adbc_driver_manager/adbc_driver_manager.cc include adbc_driver_manager/adbc_driver_manager.h +include adbc_driver_manager/_lib.pxd +include adbc_driver_manager/_lib.pyi +include adbc_driver_manager/_reader.pyi include adbc_driver_manager/py.typed diff --git a/python/adbc_driver_manager/adbc_driver_manager/__init__.py b/python/adbc_driver_manager/adbc_driver_manager/__init__.py index e2eaee5701..25b821eb80 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/__init__.py +++ b/python/adbc_driver_manager/adbc_driver_manager/__init__.py @@ -90,6 +90,8 @@ class DatabaseOptions(enum.Enum): #: Set the password to use for username-password authentication. PASSWORD = "password" + #: The URI to connect to. + URI = "uri" #: Set the username to use for username-password authentication. USERNAME = "username" @@ -100,6 +102,10 @@ class ConnectionOptions(enum.Enum): Not all drivers support all options. """ + #: Get/set the current catalog. + CURRENT_CATALOG = "adbc.connection.catalog" + #: Get/set the current schema. + CURRENT_DB_SCHEMA = "adbc.connection.db_schema" #: Set the transaction isolation level. ISOLATION_LEVEL = "adbc.connection.transaction.isolation_level" @@ -110,7 +116,11 @@ class StatementOptions(enum.Enum): Not all drivers support all options. """ + #: Enable incremental execution on ExecutePartitions. + INCREMENTAL = "adbc.statement.exec.incremental" #: For bulk ingestion, whether to create or append to the table. INGEST_MODE = INGEST_OPTION_MODE #: For bulk ingestion, the table to ingest into. INGEST_TARGET_TABLE = INGEST_OPTION_TARGET_TABLE + #: Get progress of a query. + PROGRESS = "adbc.statement.exec.progress" diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd new file mode 100644 index 0000000000..88a61a66ec --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t + + +cdef extern from "adbc.h" nogil: + # C ABI + cdef struct CArrowSchema"ArrowSchema": + pass + cdef struct CArrowArray"ArrowArray": + pass + + ctypedef int (*CArrowArrayStreamGetLastError)(void*) + ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*) + ctypedef char* (*CArrowArrayStreamGetSchema)(void*, CArrowSchema*) + ctypedef void (*CArrowArrayStreamRelease)(void*) + + cdef struct CArrowArrayStream"ArrowArrayStream": + CArrowArrayStreamGetLastError get_last_error + CArrowArrayStreamGetNext get_next + CArrowArrayStreamGetSchema get_schema + CArrowArrayStreamRelease release + + # ADBC + ctypedef uint8_t CAdbcStatusCode"AdbcStatusCode" + cdef CAdbcStatusCode ADBC_STATUS_OK + cdef CAdbcStatusCode ADBC_STATUS_UNKNOWN + cdef CAdbcStatusCode ADBC_STATUS_NOT_IMPLEMENTED + cdef CAdbcStatusCode ADBC_STATUS_NOT_FOUND + cdef CAdbcStatusCode ADBC_STATUS_ALREADY_EXISTS + cdef CAdbcStatusCode ADBC_STATUS_INVALID_ARGUMENT + cdef CAdbcStatusCode ADBC_STATUS_INVALID_STATE + cdef CAdbcStatusCode ADBC_STATUS_INVALID_DATA + cdef CAdbcStatusCode ADBC_STATUS_INTEGRITY + cdef CAdbcStatusCode ADBC_STATUS_INTERNAL + cdef CAdbcStatusCode ADBC_STATUS_IO + cdef CAdbcStatusCode ADBC_STATUS_CANCELLED + cdef CAdbcStatusCode ADBC_STATUS_TIMEOUT + cdef CAdbcStatusCode ADBC_STATUS_UNAUTHENTICATED + cdef CAdbcStatusCode ADBC_STATUS_UNAUTHORIZED + + cdef const char* ADBC_OPTION_VALUE_DISABLED + cdef const char* ADBC_OPTION_VALUE_ENABLED + + cdef const char* ADBC_CONNECTION_OPTION_AUTOCOMMIT + cdef const char* ADBC_INGEST_OPTION_TARGET_TABLE + cdef const char* ADBC_INGEST_OPTION_MODE + cdef const char* ADBC_INGEST_OPTION_MODE_APPEND + cdef const char* ADBC_INGEST_OPTION_MODE_CREATE + cdef const char* ADBC_INGEST_OPTION_MODE_REPLACE + cdef const char* ADBC_INGEST_OPTION_MODE_CREATE_APPEND + + cdef int ADBC_OBJECT_DEPTH_ALL + cdef int ADBC_OBJECT_DEPTH_CATALOGS + cdef int ADBC_OBJECT_DEPTH_DB_SCHEMAS + cdef int ADBC_OBJECT_DEPTH_TABLES + cdef int ADBC_OBJECT_DEPTH_COLUMNS + + cdef uint32_t ADBC_INFO_VENDOR_NAME + cdef uint32_t ADBC_INFO_VENDOR_VERSION + cdef uint32_t ADBC_INFO_VENDOR_ARROW_VERSION + cdef uint32_t ADBC_INFO_DRIVER_NAME + cdef uint32_t ADBC_INFO_DRIVER_VERSION + cdef uint32_t ADBC_INFO_DRIVER_ARROW_VERSION + + ctypedef void (*CAdbcErrorRelease)(CAdbcError*) + ctypedef void (*CAdbcPartitionsRelease)(CAdbcPartitions*) + + cdef struct CAdbcError"AdbcError": + char* message + int32_t vendor_code + char[5] sqlstate + CAdbcErrorRelease release + + cdef struct CAdbcErrorDetail"AdbcErrorDetail": + char* key + uint8_t* value + size_t value_length + + int AdbcErrorGetDetailCount(CAdbcError* error) + CAdbcErrorDetail AdbcErrorGetDetail(CAdbcError* error, int index) + CAdbcError* AdbcErrorFromArrayStream(CArrowArrayStream*, CAdbcStatusCode*) + + cdef int ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + + cdef struct CAdbcDriver"AdbcDriver": + pass + + cdef struct CAdbcDatabase"AdbcDatabase": + void* private_data + + cdef struct CAdbcConnection"AdbcConnection": + void* private_data + + cdef struct CAdbcStatement"AdbcStatement": + void* private_data + + cdef struct CAdbcPartitions"AdbcPartitions": + size_t num_partitions + const uint8_t** partitions + const size_t* partition_lengths + void* private_data + CAdbcPartitionsRelease release + + CAdbcStatusCode AdbcDatabaseNew(CAdbcDatabase* database, CAdbcError* error) + CAdbcStatusCode AdbcDatabaseGetOption( + CAdbcDatabase*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionBytes( + CAdbcDatabase*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionDouble( + CAdbcDatabase*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionInt( + CAdbcDatabase*, const char*, int64_t*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseSetOption( + CAdbcDatabase*, const char*, const char*, CAdbcError*) + CAdbcStatusCode AdbcDatabaseSetOptionBytes( + CAdbcDatabase*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcDatabaseSetOptionDouble( + CAdbcDatabase*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcDatabaseSetOptionInt( + CAdbcDatabase*, const char*, int64_t, CAdbcError*) + CAdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error) + CAdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error) + + ctypedef void (*CAdbcDriverInitFunc "AdbcDriverInitFunc")(int, void*, CAdbcError*) + CAdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc( + CAdbcDatabase* database, + CAdbcDriverInitFunc init_func, + CAdbcError* error) + + CAdbcStatusCode AdbcConnectionCancel(CAdbcConnection*, CAdbcError*) + CAdbcStatusCode AdbcConnectionCommit( + CAdbcConnection* connection, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionRollback( + CAdbcConnection* connection, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionReadPartition( + CAdbcConnection* connection, + const uint8_t* serialized_partition, + size_t serialized_length, + CArrowArrayStream* out, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetInfo( + CAdbcConnection* connection, + const uint32_t* info_codes, + size_t info_codes_length, + CArrowArrayStream* stream, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetObjects( + CAdbcConnection* connection, + int depth, + const char* catalog, + const char* db_schema, + const char* table_name, + const char** table_type, + const char* column_name, + CArrowArrayStream* stream, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetOption( + CAdbcConnection*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionBytes( + CAdbcConnection*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionDouble( + CAdbcConnection*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionInt( + CAdbcConnection*, const char*, int64_t*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetStatistics( + CAdbcConnection*, const char*, const char*, const char*, + char, CArrowArrayStream*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetStatisticNames( + CAdbcConnection*, CArrowArrayStream*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetTableSchema( + CAdbcConnection* connection, + const char* catalog, + const char* db_schema, + const char* table_name, + CArrowSchema* schema, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetTableTypes( + CAdbcConnection* connection, + CArrowArrayStream* stream, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionInit( + CAdbcConnection* connection, + CAdbcDatabase* database, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionNew( + CAdbcConnection* connection, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionRelease( + CAdbcConnection* connection, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionSetOption( + CAdbcConnection*, const char*, const char*, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionBytes( + CAdbcConnection*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionDouble( + CAdbcConnection*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionInt( + CAdbcConnection*, const char*, int64_t, CAdbcError*) + CAdbcStatusCode AdbcStatementBind( + CAdbcStatement* statement, + CArrowArray*, + CArrowSchema*, + CAdbcError* error) + + CAdbcStatusCode AdbcStatementCancel(CAdbcStatement*, CAdbcError*) + CAdbcStatusCode AdbcStatementBindStream( + CAdbcStatement* statement, + CArrowArrayStream*, + CAdbcError* error) + CAdbcStatusCode AdbcStatementExecutePartitions( + CAdbcStatement* statement, + CArrowSchema* schema, CAdbcPartitions* partitions, + int64_t* rows_affected, + CAdbcError* error) + CAdbcStatusCode AdbcStatementExecuteQuery( + CAdbcStatement* statement, + CArrowArrayStream* out, int64_t* rows_affected, + CAdbcError* error) + CAdbcStatusCode AdbcStatementExecuteSchema( + CAdbcStatement*, CArrowSchema*, CAdbcError*) + CAdbcStatusCode AdbcStatementGetOption( + CAdbcStatement*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionBytes( + CAdbcStatement*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionDouble( + CAdbcStatement*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionInt( + CAdbcStatement*, const char*, int64_t*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetParameterSchema( + CAdbcStatement* statement, + CArrowSchema* schema, + CAdbcError* error); + CAdbcStatusCode AdbcStatementNew( + CAdbcConnection* connection, + CAdbcStatement* statement, + CAdbcError* error) + CAdbcStatusCode AdbcStatementPrepare( + CAdbcStatement* statement, + CAdbcError* error) + CAdbcStatusCode AdbcStatementSetOption( + CAdbcStatement*, const char*, const char*, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionBytes( + CAdbcStatement*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionDouble( + CAdbcStatement*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionInt( + CAdbcStatement*, const char*, int64_t, CAdbcError*) + CAdbcStatusCode AdbcStatementSetSqlQuery( + CAdbcStatement* statement, + const char* query, + CAdbcError* error) + CAdbcStatusCode AdbcStatementSetSubstraitPlan( + CAdbcStatement* statement, + const uint8_t* plan, + size_t length, + CAdbcError* error) + CAdbcStatusCode AdbcStatementRelease( + CAdbcStatement* statement, + CAdbcError* error) + +cdef const CAdbcError* PyAdbcErrorFromArrayStream( + CArrowArrayStream* stream, CAdbcStatusCode* status) + +cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except * + +cdef extern from "adbc_driver_manager.h": + const char* CAdbcStatusCodeMessage"AdbcStatusCodeMessage"(CAdbcStatusCode code) diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi index 8f107369ea..28e6be16af 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi @@ -26,15 +26,28 @@ import typing INGEST_OPTION_MODE: str INGEST_OPTION_MODE_APPEND: str INGEST_OPTION_MODE_CREATE: str +INGEST_OPTION_MODE_CREATE_APPEND: str +INGEST_OPTION_MODE_REPLACE: str INGEST_OPTION_TARGET_TABLE: str class AdbcConnection(_AdbcHandle): def __init__(self, database: "AdbcDatabase", **kwargs: str) -> None: ... + def cancel(self) -> None: ... def close(self) -> None: ... def commit(self) -> None: ... def get_info( self, info_codes: Optional[List[Union[int, "AdbcInfoCode"]]] = None ) -> "ArrowArrayStreamHandle": ... + def get_option( + self, + key: Union[bytes, str], + *, + encoding: str = "utf-8", + errors: str = "strict", + ) -> str: ... + def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ... + def get_option_float(self, key: Union[bytes, str]) -> float: ... + def get_option_int(self, key: Union[bytes, str]) -> int: ... def get_objects( self, depth: "GetObjectsDepth", @@ -54,12 +67,22 @@ class AdbcConnection(_AdbcHandle): def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ... def rollback(self) -> None: ... def set_autocommit(self, enabled: bool) -> None: ... - def set_options(self, **kwargs: str) -> None: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... class AdbcDatabase(_AdbcHandle): def __init__(self, **kwargs: str) -> None: ... def close(self) -> None: ... - def set_options(self, **kwargs: str) -> None: ... + def get_option( + self, + key: Union[bytes, str], + *, + encoding: str = "utf-8", + errors: str = "strict", + ) -> str: ... + def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ... + def get_option_float(self, key: Union[bytes, str]) -> float: ... + def get_option_int(self, key: Union[bytes, str]) -> int: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... class AdbcInfoCode(enum.IntEnum): DRIVER_ARROW_VERSION = ... @@ -73,13 +96,25 @@ class AdbcStatement(_AdbcHandle): def __init__(self, *args, **kwargs) -> None: ... def bind(self, *args, **kwargs) -> Any: ... def bind_stream(self, *args, **kwargs) -> Any: ... + def cancel(self) -> None: ... def close(self) -> None: ... def execute_partitions(self, *args, **kwargs) -> Any: ... def execute_query(self, *args, **kwargs) -> Any: ... + def execute_schema(self) -> "ArrowSchemaHandle": ... def execute_update(self, *args, **kwargs) -> Any: ... + def get_option( + self, + key: Union[bytes, str], + *, + encoding: str = "utf-8", + errors: str = "strict", + ) -> str: ... + def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ... + def get_option_float(self, key: Union[bytes, str]) -> float: ... + def get_option_int(self, key: Union[bytes, str]) -> int: ... def get_parameter_schema(self, *args, **kwargs) -> Any: ... def prepare(self, *args, **kwargs) -> Any: ... - def set_options(self, *args, **kwargs) -> Any: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... def set_sql_query(self, *args, **kwargs) -> Any: ... def set_substrait_plan(self, *args, **kwargs) -> Any: ... def __reduce__(self) -> Any: ... @@ -118,6 +153,7 @@ class Error(Exception): status_code: AdbcStatusCode vendor_code: Optional[int] sqlstate: Optional[str] + details: List[Tuple[str, bytes]] def __init__( self, @@ -125,7 +161,8 @@ class Error(Exception): *, status_code: Union[int, AdbcStatusCode], vendor_code: Optional[str] = None, - sqlstate: Optional[str] = None + sqlstate: Optional[str] = None, + details: Optional[List[Tuple[str, bytes]]] = None, ) -> None: ... class GetObjectsDepth(enum.IntEnum): @@ -145,7 +182,8 @@ class NotSupportedError(DatabaseError): message: str, *, vendor_code: Optional[str] = None, - sqlstate: Optional[str] = None + sqlstate: Optional[str] = None, + details: Optional[List[Tuple[str, bytes]]] = None, ) -> None: ... class OperationalError(DatabaseError): ... diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index a096148821..e21130a710 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -34,205 +34,6 @@ if typing.TYPE_CHECKING: from typing import Self -cdef extern from "adbc.h" nogil: - # C ABI - cdef struct CArrowSchema"ArrowSchema": - pass - cdef struct CArrowArray"ArrowArray": - pass - cdef struct CArrowArrayStream"ArrowArrayStream": - pass - - # ADBC - ctypedef uint8_t CAdbcStatusCode"AdbcStatusCode" - cdef CAdbcStatusCode ADBC_STATUS_OK - cdef CAdbcStatusCode ADBC_STATUS_UNKNOWN - cdef CAdbcStatusCode ADBC_STATUS_NOT_IMPLEMENTED - cdef CAdbcStatusCode ADBC_STATUS_NOT_FOUND - cdef CAdbcStatusCode ADBC_STATUS_ALREADY_EXISTS - cdef CAdbcStatusCode ADBC_STATUS_INVALID_ARGUMENT - cdef CAdbcStatusCode ADBC_STATUS_INVALID_STATE - cdef CAdbcStatusCode ADBC_STATUS_INVALID_DATA - cdef CAdbcStatusCode ADBC_STATUS_INTEGRITY - cdef CAdbcStatusCode ADBC_STATUS_INTERNAL - cdef CAdbcStatusCode ADBC_STATUS_IO - cdef CAdbcStatusCode ADBC_STATUS_CANCELLED - cdef CAdbcStatusCode ADBC_STATUS_TIMEOUT - cdef CAdbcStatusCode ADBC_STATUS_UNAUTHENTICATED - cdef CAdbcStatusCode ADBC_STATUS_UNAUTHORIZED - - cdef const char* ADBC_OPTION_VALUE_DISABLED - cdef const char* ADBC_OPTION_VALUE_ENABLED - - cdef const char* ADBC_CONNECTION_OPTION_AUTOCOMMIT - cdef const char* ADBC_INGEST_OPTION_TARGET_TABLE - cdef const char* ADBC_INGEST_OPTION_MODE - cdef const char* ADBC_INGEST_OPTION_MODE_APPEND - cdef const char* ADBC_INGEST_OPTION_MODE_CREATE - - cdef int ADBC_OBJECT_DEPTH_ALL - cdef int ADBC_OBJECT_DEPTH_CATALOGS - cdef int ADBC_OBJECT_DEPTH_DB_SCHEMAS - cdef int ADBC_OBJECT_DEPTH_TABLES - cdef int ADBC_OBJECT_DEPTH_COLUMNS - - cdef uint32_t ADBC_INFO_VENDOR_NAME - cdef uint32_t ADBC_INFO_VENDOR_VERSION - cdef uint32_t ADBC_INFO_VENDOR_ARROW_VERSION - cdef uint32_t ADBC_INFO_DRIVER_NAME - cdef uint32_t ADBC_INFO_DRIVER_VERSION - cdef uint32_t ADBC_INFO_DRIVER_ARROW_VERSION - - ctypedef void (*CAdbcErrorRelease)(CAdbcError*) - ctypedef void (*CAdbcPartitionsRelease)(CAdbcPartitions*) - - cdef struct CAdbcError"AdbcError": - char* message - int32_t vendor_code - char[5] sqlstate - CAdbcErrorRelease release - - cdef struct CAdbcDriver"AdbcDriver": - pass - - cdef struct CAdbcDatabase"AdbcDatabase": - void* private_data - - cdef struct CAdbcConnection"AdbcConnection": - void* private_data - - cdef struct CAdbcStatement"AdbcStatement": - void* private_data - - cdef struct CAdbcPartitions"AdbcPartitions": - size_t num_partitions - const uint8_t** partitions - const size_t* partition_lengths - void* private_data - CAdbcPartitionsRelease release - - CAdbcStatusCode AdbcDatabaseNew(CAdbcDatabase* database, CAdbcError* error) - CAdbcStatusCode AdbcDatabaseSetOption( - CAdbcDatabase* database, - const char* key, - const char* value, - CAdbcError* error) - CAdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error) - CAdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error) - - ctypedef void (*CAdbcDriverInitFunc "AdbcDriverInitFunc")(int, void*, CAdbcError*) - CAdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc( - CAdbcDatabase* database, - CAdbcDriverInitFunc init_func, - CAdbcError* error) - - CAdbcStatusCode AdbcConnectionCommit( - CAdbcConnection* connection, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionRollback( - CAdbcConnection* connection, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionReadPartition( - CAdbcConnection* connection, - const uint8_t* serialized_partition, - size_t serialized_length, - CArrowArrayStream* out, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionGetInfo( - CAdbcConnection* connection, - uint32_t* info_codes, - size_t info_codes_length, - CArrowArrayStream* stream, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionGetObjects( - CAdbcConnection* connection, - int depth, - const char* catalog, - const char* db_schema, - const char* table_name, - const char** table_type, - const char* column_name, - CArrowArrayStream* stream, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionGetTableSchema( - CAdbcConnection* connection, - const char* catalog, - const char* db_schema, - const char* table_name, - CArrowSchema* schema, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionGetTableTypes( - CAdbcConnection* connection, - CArrowArrayStream* stream, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionInit( - CAdbcConnection* connection, - CAdbcDatabase* database, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionNew( - CAdbcConnection* connection, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionSetOption( - CAdbcConnection* connection, - const char* key, - const char* value, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionRelease( - CAdbcConnection* connection, - CAdbcError* error) - - CAdbcStatusCode AdbcStatementBind( - CAdbcStatement* statement, - CArrowArray*, - CArrowSchema*, - CAdbcError* error) - CAdbcStatusCode AdbcStatementBindStream( - CAdbcStatement* statement, - CArrowArrayStream*, - CAdbcError* error) - CAdbcStatusCode AdbcStatementExecutePartitions( - CAdbcStatement* statement, - CArrowSchema* schema, CAdbcPartitions* partitions, - int64_t* rows_affected, - CAdbcError* error) - CAdbcStatusCode AdbcStatementExecuteQuery( - CAdbcStatement* statement, - CArrowArrayStream* out, int64_t* rows_affected, - CAdbcError* error) - CAdbcStatusCode AdbcStatementGetParameterSchema( - CAdbcStatement* statement, - CArrowSchema* schema, - CAdbcError* error); - CAdbcStatusCode AdbcStatementNew( - CAdbcConnection* connection, - CAdbcStatement* statement, - CAdbcError* error) - CAdbcStatusCode AdbcStatementPrepare( - CAdbcStatement* statement, - CAdbcError* error) - CAdbcStatusCode AdbcStatementSetOption( - CAdbcStatement* statement, - const char* key, - const char* value, - CAdbcError* error) - CAdbcStatusCode AdbcStatementSetSqlQuery( - CAdbcStatement* statement, - const char* query, - CAdbcError* error) - CAdbcStatusCode AdbcStatementSetSubstraitPlan( - CAdbcStatement* statement, - const uint8_t* plan, - size_t length, - CAdbcError* error) - CAdbcStatusCode AdbcStatementRelease( - CAdbcStatement* statement, - CAdbcError* error) - - -cdef extern from "adbc_driver_manager.h": - const char* CAdbcStatusCodeMessage"AdbcStatusCodeMessage"(CAdbcStatusCode code) - - class AdbcStatusCode(enum.IntEnum): """ A status code indicating the type of error. @@ -282,13 +83,16 @@ class Error(Exception): A vendor-specific status code if present. sqlstate : str, optional The SQLSTATE code if present. + details : list[tuple[str, bytes]], optional + Additional error details, if present. """ - def __init__(self, message, *, status_code, vendor_code=None, sqlstate=None): + def __init__(self, message, *, status_code, vendor_code=None, sqlstate=None, details=None): super().__init__(message) self.status_code = AdbcStatusCode(status_code) self.vendor_code = vendor_code self.sqlstate = sqlstate + self.details = details or [] class InterfaceError(Error): @@ -322,12 +126,13 @@ class ProgrammingError(DatabaseError): class NotSupportedError(DatabaseError): """An operation or some functionality is not supported.""" - def __init__(self, message, *, vendor_code=None, sqlstate=None): + def __init__(self, message, *, vendor_code=None, sqlstate=None, details=None): super().__init__( message, status_code=AdbcStatusCode.NOT_IMPLEMENTED, vendor_code=vendor_code, sqlstate=sqlstate, + details=details, ) @@ -348,10 +153,14 @@ NotSupportedError.__module__ = "adbc_driver_manager" INGEST_OPTION_MODE = ADBC_INGEST_OPTION_MODE.decode("utf-8") INGEST_OPTION_MODE_APPEND = ADBC_INGEST_OPTION_MODE_APPEND.decode("utf-8") INGEST_OPTION_MODE_CREATE = ADBC_INGEST_OPTION_MODE_CREATE.decode("utf-8") +INGEST_OPTION_MODE_REPLACE = ADBC_INGEST_OPTION_MODE_REPLACE.decode("utf-8") +INGEST_OPTION_MODE_CREATE_APPEND = ADBC_INGEST_OPTION_MODE_CREATE_APPEND.decode("utf-8") INGEST_OPTION_TARGET_TABLE = ADBC_INGEST_OPTION_TARGET_TABLE.decode("utf-8") cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *: + cdef CAdbcErrorDetail c_detail + if status == ADBC_STATUS_OK: return @@ -362,14 +171,26 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *: if error != NULL: if error.message != NULL: message += ": " - message += error.message.decode("utf-8") - if error.vendor_code: + message += error.message.decode("utf-8", "replace") + if error.vendor_code and error.vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA: vendor_code = error.vendor_code message += f". Vendor code: {vendor_code}" if error.sqlstate[0] != 0: sqlstate = bytes(error.sqlstate[i] for i in range(5)) - sqlstate = sqlstate.decode("ascii") + sqlstate = sqlstate.decode("ascii", "replace") message += f". SQLSTATE: {sqlstate}" + + num_details = AdbcErrorGetDetailCount(error) + details = [] + for index in range(num_details): + c_detail = AdbcErrorGetDetail(error, index) + if c_detail.key == NULL or c_detail.value == NULL: + # Shouldn't happen... + break + details.append( + (c_detail.key, + PyBytes_FromStringAndSize( c_detail.value, c_detail.value_length))) + if error.release: error.release(error) @@ -395,13 +216,15 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *: ADBC_STATUS_UNAUTHORIZED): klass = ProgrammingError elif status == ADBC_STATUS_NOT_IMPLEMENTED: - raise NotSupportedError(message, vendor_code=vendor_code, sqlstate=sqlstate) - raise klass(message, status_code=status, vendor_code=vendor_code, sqlstate=sqlstate) + raise NotSupportedError(message, vendor_code=vendor_code, sqlstate=sqlstate, details=details) + raise klass(message, status_code=status, vendor_code=vendor_code, sqlstate=sqlstate, details=details) cdef CAdbcError empty_error(): cdef CAdbcError error memset(&error, 0, cython.sizeof(error)) + # We always want the extended error info + error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA return error @@ -521,6 +344,11 @@ class GetObjectsDepth(enum.IntEnum): COLUMNS = ADBC_OBJECT_DEPTH_COLUMNS +# Assume a driver won't return more than 128 MiB of option data at +# once. +_MAX_OPTION_SIZE = 2**27 + + cdef class AdbcDatabase(_AdbcHandle): """ An instance of a database. @@ -556,8 +384,8 @@ cdef class AdbcDatabase(_AdbcHandle): elif value is None: raise ValueError(f"value for key '{key}' cannot be None") else: - key = key.encode("utf-8") - value = value.encode("utf-8") + key = _to_bytes(key, "key") + value = _to_bytes(value, "value") c_key = key c_value = value status = AdbcDatabaseSetOption( @@ -581,8 +409,115 @@ cdef class AdbcDatabase(_AdbcHandle): status = AdbcDatabaseRelease(&self.database, &c_error) check_error(status, &c_error) + def get_option( + self, + key: str | bytes, + *, + encoding="utf-8", + errors="strict", + ) -> str: + """ + Get the value of a string option. + + Parameters + ---------- + key : str or bytes + The option to get. + encoding : str + The encoding of the option value. This should almost + always be UTF-8. + errors : str + What to do about errors when decoding the option value + (see bytes.decode). + """ + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcDatabaseGetOption( + &self.database, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode(encoding, errors) + + def get_option_bytes(self, key: str) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcDatabaseGetOptionBytes( + &self.database, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcDatabaseGetOptionDouble( + &self.database, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str) -> int: + """Get the value of an integer option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcDatabaseGetOptionInt( + &self.database, c_key, &c_value, &c_error), + &c_error) + return c_value + def set_options(self, **kwargs) -> None: - """Set arbitrary key-value options. + """ + Set arbitrary key-value options. Pass options as kwargs: ``set_options(**{"some.option": "value"})``. @@ -591,23 +526,38 @@ cdef class AdbcDatabase(_AdbcHandle): See Also -------- adbc_driver_manager.DatabaseOptions : Standard option names. - """ cdef CAdbcError c_error = empty_error() cdef char* c_key = NULL cdef char* c_value = NULL for key, value in kwargs.items(): - key = key.encode("utf-8") + key = _to_bytes(key, "option key") c_key = key if value is None: c_value = NULL - else: - value = value.encode("utf-8") + status = AdbcDatabaseSetOption( + &self.database, c_key, c_value, &c_error) + elif isinstance(value, str): + value = _to_bytes(value, "option value") + c_value = value + status = AdbcDatabaseSetOption( + &self.database, c_key, c_value, &c_error) + elif isinstance(value, bytes): c_value = value + status = AdbcDatabaseSetOptionBytes( + &self.database, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcDatabaseSetOptionDouble( + &self.database, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcDatabaseSetOptionInt( + &self.database, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcDatabaseSetOption( - &self.database, c_key, c_value, &c_error) check_error(status, &c_error) @@ -661,6 +611,14 @@ cdef class AdbcConnection(_AdbcHandle): database._open_child() + def cancel(self) -> None: + """Attempt to cancel any ongoing operations on the connection.""" + cdef CAdbcError c_error = empty_error() + cdef CAdbcStatusCode status + with nogil: + status = AdbcConnectionCancel(&self.connection, &c_error) + check_error(status, &c_error) + def commit(self) -> None: """Commit the current transaction.""" cdef CAdbcError c_error = empty_error() @@ -749,6 +707,112 @@ cdef class AdbcConnection(_AdbcHandle): return stream + def get_option( + self, + key: str | bytes, + *, + encoding="utf-8", + errors="strict", + ) -> str: + """ + Get the value of a string option. + + Parameters + ---------- + key : str or bytes + The option to get. + encoding : str + The encoding of the option value. This should almost + always be UTF-8. + errors : str + What to do about errors when decoding the option value + (see bytes.decode). + """ + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcConnectionGetOption( + &self.connection, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode(encoding, errors) + + def get_option_bytes(self, key: str) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcConnectionGetOptionBytes( + &self.connection, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcConnectionGetOptionDouble( + &self.connection, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str) -> int: + """Get the value of an integer option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcConnectionGetOptionInt( + &self.connection, c_key, &c_value, &c_error), + &c_error) + return c_value + def get_table_schema(self, catalog, db_schema, table_name) -> ArrowSchemaHandle: """ Get the Arrow schema of a table. @@ -839,7 +903,8 @@ cdef class AdbcConnection(_AdbcHandle): check_error(status, &c_error) def set_options(self, **kwargs) -> None: - """Set arbitrary key-value options. + """ + Set arbitrary key-value options. Pass options as kwargs: ``set_options(**{"some.option": "value"})``. @@ -853,17 +918,33 @@ cdef class AdbcConnection(_AdbcHandle): cdef char* c_key = NULL cdef char* c_value = NULL for key, value in kwargs.items(): - key = key.encode("utf-8") + key = _to_bytes(key, "option key") c_key = key if value is None: c_value = NULL - else: - value = value.encode("utf-8") + status = AdbcConnectionSetOption( + &self.connection, c_key, c_value, &c_error) + elif isinstance(value, str): + value = _to_bytes(value, "option value") c_value = value + status = AdbcConnectionSetOption( + &self.connection, c_key, c_value, &c_error) + elif isinstance(value, bytes): + c_value = value + status = AdbcConnectionSetOptionBytes( + &self.connection, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcConnectionSetOptionDouble( + &self.connection, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcConnectionSetOptionInt( + &self.connection, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcConnectionSetOption( - &self.connection, c_key, c_value, &c_error) check_error(status, &c_error) def close(self) -> None: @@ -974,7 +1055,16 @@ cdef class AdbcStatement(_AdbcHandle): &c_error) check_error(status, &c_error) + def cancel(self) -> None: + """Attempt to cancel any ongoing operations on the connection.""" + cdef CAdbcError c_error = empty_error() + cdef CAdbcStatusCode status + with nogil: + status = AdbcStatementCancel(&self.statement, &c_error) + check_error(status, &c_error) + def close(self) -> None: + """Release the handle to the statement.""" cdef CAdbcError c_error = empty_error() cdef CAdbcStatusCode status self.connection._close_child() @@ -1048,6 +1138,25 @@ cdef class AdbcStatement(_AdbcHandle): return (partitions, schema, rows_affected) + def execute_schema(self) -> ArrowSchemaHandle: + """ + Get the schema of the result set without executing the query. + + Returns + ------- + ArrowSchemaHandle + The schema of the result set. + """ + cdef CAdbcError c_error = empty_error() + cdef ArrowSchemaHandle schema = ArrowSchemaHandle() + with nogil: + status = AdbcStatementExecuteSchema( + &self.statement, + &schema.schema, + &c_error) + check_error(status, &c_error) + return schema + def execute_update(self) -> int: """ Execute the query without a result set. @@ -1068,6 +1177,112 @@ cdef class AdbcStatement(_AdbcHandle): check_error(status, &c_error) return rows_affected + def get_option( + self, + key: str | bytes, + *, + encoding="utf-8", + errors="strict", + ) -> str: + """ + Get the value of a string option. + + Parameters + ---------- + key : str or bytes + The option to get. + encoding : str + The encoding of the option value. This should almost + always be UTF-8. + errors : str + What to do about errors when decoding the option value + (see bytes.decode). + """ + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcStatementGetOption( + &self.statement, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode(encoding, errors) + + def get_option_bytes(self, key: str) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcStatementGetOptionBytes( + &self.statement, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcStatementGetOptionDouble( + &self.statement, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str) -> int: + """Get the value of an integer option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcStatementGetOptionInt( + &self.statement, c_key, &c_value, &c_error), + &c_error) + return c_value + def get_parameter_schema(self) -> ArrowSchemaHandle: """Get the Arrow schema for bound parameters. @@ -1112,6 +1327,8 @@ cdef class AdbcStatement(_AdbcHandle): Pass options as kwargs: ``set_options(**{"some.option": "value"})``. + Note, not all drivers support setting options after creation. + See Also -------- adbc_driver_manager.StatementOptions : Standard option names. @@ -1120,17 +1337,33 @@ cdef class AdbcStatement(_AdbcHandle): cdef char* c_key = NULL cdef char* c_value = NULL for key, value in kwargs.items(): - key = key.encode("utf-8") + key = _to_bytes(key, "option key") c_key = key if value is None: c_value = NULL - else: - value = value.encode("utf-8") + status = AdbcStatementSetOption( + &self.statement, c_key, c_value, &c_error) + elif isinstance(value, str): + value = _to_bytes(value, "option value") + c_value = value + status = AdbcStatementSetOption( + &self.statement, c_key, c_value, &c_error) + elif isinstance(value, bytes): c_value = value + status = AdbcStatementSetOptionBytes( + &self.statement, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcStatementSetOptionDouble( + &self.statement, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcStatementSetOptionInt( + &self.statement, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcStatementSetOption( - &self.statement, c_key, c_value, &c_error) check_error(status, &c_error) def set_sql_query(self, str query not None) -> None: @@ -1152,3 +1385,8 @@ cdef class AdbcStatement(_AdbcHandle): status = AdbcStatementSetSubstraitPlan( &self.statement, c_plan, length, &c_error) check_error(status, &c_error) + + +cdef const CAdbcError* PyAdbcErrorFromArrayStream( + CArrowArrayStream* stream, CAdbcStatusCode* status): + return AdbcErrorFromArrayStream(stream, status) diff --git a/python/adbc_driver_manager/adbc_driver_manager/_reader.pyi b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyi new file mode 100644 index 0000000000..48fc718dd4 --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyi @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import typing + +import pandas +import pyarrow + +class AdbcRecordBatchReader(pyarrow.RecordBatchReader): + def close(self) -> None: ... + def read_all(self) -> pyarrow.Table: ... + def read_next_batch(self) -> pyarrow.RecordBatch: ... + def read_pandas(self, **kwargs) -> pandas.DataFrame: ... + @property + def schema(self) -> pyarrow.Schema: ... + @classmethod + def _import_from_c(cls, address: int) -> AdbcRecordBatchReader: ... + def __enter__(self) -> AdbcRecordBatchReader: ... + def __exit__(self, type, value, traceback) -> None: ... + def __iter__(self) -> typing.Iterator[pyarrow.RecordBatch]: ... diff --git a/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx new file mode 100644 index 0000000000..43de674023 --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +import pyarrow +from cython.operator cimport dereference as deref +from libc.stdint cimport uintptr_t + +from ._lib cimport * + + +cdef class _AdbcErrorHelper: + cdef: + CArrowArrayStream c_stream + + def check_error(self, exception): + cdef: + CAdbcStatusCode c_status + const CAdbcError* error = PyAdbcErrorFromArrayStream(&self.c_stream, &c_status) + CAdbcErrorDetail detail + + if error != NULL: + check_error(c_status, error) + + raise exception + + +# Can't directly inherit pyarrow.RecordBatchReader, but we want to in order to +# pass isinstance checks +class AdbcRecordBatchReader(pyarrow.RecordBatchReader): + def __init__(self, reader, helper): + self._reader = reader + self._helper = helper + + def close(self): + self._reader.close() + + @classmethod + def _import_from_c(cls, address) -> AdbcRecordBatchReader: + cdef: + CArrowArrayStream* c_stream = address + _AdbcErrorHelper helper = _AdbcErrorHelper.__new__(_AdbcErrorHelper) + # Save a copy of the stream to use + helper.c_stream = deref(c_stream) + try: + reader = pyarrow.RecordBatchReader._import_from_c(int(address)) + except Exception as e: + helper.check_error(e) + return cls(reader, helper) + + def __enter__(self): + return self + + def __exit__(self, exc_info, exc_val, exc_tb): + pass + + def __iter__(self): + try: + yield from self._reader + except Exception as e: + self._helper.check_error(e) + + def iter_batches_with_custom_metadata(self): + try: + return self._reader.iter_batches_with_custom_metadata() + except Exception as e: + self._helper.check_error(e) + + def read_all(self): + try: + return self._reader.read_all() + except Exception as e: + self._helper.check_error(e) + + def read_next_batch(self, *args, **kwargs): + try: + return self._reader.read_next_batch(*args, **kwargs) + except Exception as e: + self._helper.check_error(e) + + def read_next_batch_with_custom_metadata(self): + try: + return self._reader.read_next_batch_with_custom_metadata() + except Exception as e: + self._helper.check_error(e) + + def read_pandas(self, *args, **kwargs): + try: + return self._reader.read_pandas(*args, **kwargs) + except Exception as e: + self._helper.check_error(e) + + @property + def schema(self): + try: + return self._reader.schema + except Exception as e: + self._helper.check_error(e) diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 60bc2d1bce..f28fbf0ee7 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -43,7 +43,9 @@ except ImportError as e: raise ImportError("PyArrow is required for the DBAPI-compatible interface") from e -from . import _lib +import adbc_driver_manager + +from . import _lib, _reader if typing.TYPE_CHECKING: import pandas @@ -78,6 +80,7 @@ 100: "driver_name", 101: "driver_version", 102: "driver_arrow_version", + 103: "driver_adbc_version", } # ---------------------------------------------------------- @@ -344,6 +347,16 @@ def __del__(self) -> None: # API Extensions # ------------------------------------------------------------ + def adbc_cancel(self) -> None: + """ + Cancel any ongoing operations on this connection. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + self._conn.cancel() + def adbc_clone(self) -> "Connection": """ Create a new Connection sharing the same underlying database. @@ -479,6 +492,40 @@ def adbc_connection(self) -> _lib.AdbcConnection: """ return self._conn + @property + def adbc_current_catalog(self) -> str: + """ + The name of the current catalog. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + key = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG.value + return self._conn.get_option(key) + + @adbc_current_catalog.setter + def adbc_current_catalog(self, catalog: str) -> None: + key = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG.value + self._conn.set_options(**{key: catalog}) + + @property + def adbc_current_db_schema(self) -> str: + """ + The name of the current schema. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + key = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA.value + return self._conn.get_option(key) + + @adbc_current_db_schema.setter + def adbc_current_db_schema(self, db_schema: str) -> None: + key = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA.value + self._conn.set_options(**{key: db_schema}) + @property def adbc_database(self) -> _lib.AdbcDatabase: """ @@ -621,7 +668,8 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None: self._prepare_execute(operation, parameters) handle, self._rowcount = self._stmt.execute_query() self._results = _RowIterator( - pyarrow.RecordBatchReader._import_from_c(handle.address) + # pyarrow.RecordBatchReader._import_from_c(handle.address) + _reader.AdbcRecordBatchReader._import_from_c(handle.address) ) def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None: @@ -729,11 +777,21 @@ def __next__(self): # API Extensions # ------------------------------------------------------------ + def adbc_cancel(self) -> None: + """ + Cancel any ongoing operations on this statement. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + self._stmt.cancel() + def adbc_ingest( self, table_name: str, data: Union[pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader], - mode: Literal["append", "create"] = "create", + mode: Literal["append", "create", "replace", "create_append"] = "create", ) -> int: """ Ingest Arrow data into a database table. @@ -748,7 +806,12 @@ def adbc_ingest( data The Arrow data to insert. mode - Whether to append data to an existing table, or create a new table. + How to deal with existing data: + + - 'append': append to a table (error if table does not exist) + - 'create': create a table and insert (error if table exists) + - 'create_append': create a table (if not exists) and insert + - 'replace': drop existing table (if any), then same as 'create' Returns ------- @@ -764,6 +827,10 @@ def adbc_ingest( c_mode = _lib.INGEST_OPTION_MODE_APPEND elif mode == "create": c_mode = _lib.INGEST_OPTION_MODE_CREATE + elif mode == "create_append": + c_mode = _lib.INGEST_OPTION_MODE_CREATE_APPEND + elif mode == "replace": + c_mode = _lib.INGEST_OPTION_MODE_REPLACE else: raise ValueError(f"Invalid value for 'mode': {mode}") self._stmt.set_options( @@ -810,6 +877,23 @@ def adbc_execute_partitions( partitions, schema, self._rowcount = self._stmt.execute_partitions() return partitions, pyarrow.Schema._import_from_c(schema.address) + def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema: + """ + Get the schema of the result set of a query without executing it. + + Returns + ------- + pyarrow.Schema + The schema of the result set. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + self._prepare_execute(operation, parameters) + schema = self._stmt.execute_schema() + return pyarrow.Schema._import_from_c(schema.address) + def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]: """ Prepare a query without executing it. diff --git a/python/adbc_driver_manager/setup.py b/python/adbc_driver_manager/setup.py index dde05a08e3..075ae8612d 100644 --- a/python/adbc_driver_manager/setup.py +++ b/python/adbc_driver_manager/setup.py @@ -72,10 +72,15 @@ def get_version_and_cmdclass(pkg_path): # ------------------------------------------------------------ # Resolve compiler flags +build_type = os.environ.get("ADBC_BUILD_TYPE", "release") + if sys.platform == "win32": extra_compile_args = ["/std:c++17", "/DADBC_EXPORTING"] else: extra_compile_args = ["-std=c++17"] + if build_type == "debug": + # Useful to step through driver manager code in GDB + extra_compile_args.extend(["-ggdb", "-Og"]) # ------------------------------------------------------------ # Setup @@ -92,7 +97,16 @@ def get_version_and_cmdclass(pkg_path): "adbc_driver_manager/_lib.pyx", "adbc_driver_manager/adbc_driver_manager.cc", ], - ) + ), + Extension( + name="adbc_driver_manager._reader", + extra_compile_args=extra_compile_args, + include_dirs=[str(source_root.joinpath("adbc_driver_manager").resolve())], + language="c++", + sources=[ + "adbc_driver_manager/_reader.pyx", + ], + ), ], version=version, ) diff --git a/python/adbc_driver_manager/tests/test_reader.py b/python/adbc_driver_manager/tests/test_reader.py new file mode 100644 index 0000000000..54d2c7d524 --- /dev/null +++ b/python/adbc_driver_manager/tests/test_reader.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pyarrow +import pytest + +from adbc_driver_manager._lib import ArrowArrayStreamHandle +from adbc_driver_manager._reader import AdbcRecordBatchReader + +schema = pyarrow.schema([("ints", "int32")]) +batches = [ + pyarrow.record_batch([[1, 2, 3, 4]], schema=schema), +] + + +def _make_reader(): + original = pyarrow.RecordBatchReader.from_batches(schema, batches) + exported = ArrowArrayStreamHandle() + original._export_to_c(exported.address) + return AdbcRecordBatchReader._import_from_c(exported.address) + + +def test_reader(): + wrapped = _make_reader() + assert wrapped.read_next_batch() == batches[0] + + +def test_reader_error(): + schema = pyarrow.schema([("ints", "int32")]) + + def batches(): + yield pyarrow.record_batch([[1, 2, 3, 4]], schema=schema) + raise ValueError("foo") + + original = pyarrow.RecordBatchReader.from_batches(schema, batches()) + exported = ArrowArrayStreamHandle() + original._export_to_c(exported.address) + wrapped = AdbcRecordBatchReader._import_from_c(exported.address) + + assert wrapped.read_next_batch() is not None + with pytest.raises(pyarrow.ArrowInvalid): + wrapped.read_next_batch() + + +def test_reader_methods(): + with _make_reader() as reader: + assert reader.read_all() == pyarrow.Table.from_batches(batches, schema) + + with _make_reader() as reader: + assert reader.read_pandas() is not None + + with _make_reader() as reader: + for batch in reader: + assert batch == batches[0] + + with _make_reader() as reader: + with pytest.raises(NotImplementedError): + assert reader.read_next_batch_with_custom_metadata() is not None + + with _make_reader() as reader: + with pytest.raises(NotImplementedError): + for batch in reader.iter_batches_with_custom_metadata(): + assert batch == batches[0] + + with _make_reader() as reader: + assert reader.schema == schema diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py index 80339d6039..f516ba0e82 100644 --- a/python/adbc_driver_postgresql/tests/test_dbapi.py +++ b/python/adbc_driver_postgresql/tests/test_dbapi.py @@ -17,6 +17,7 @@ from typing import Generator +import pyarrow import pytest from adbc_driver_postgresql import StatementOptions, dbapi @@ -28,6 +29,32 @@ def postgres(postgres_uri: str) -> Generator[dbapi.Connection, None, None]: yield conn +def test_conn_current_catalog(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_catalog != "" + + +def test_conn_current_db_schema(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_db_schema == "public" + + +def test_conn_change_db_schema(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_db_schema == "public" + + with postgres.cursor() as cur: + cur.execute("CREATE SCHEMA IF NOT EXISTS dbapischema") + + assert postgres.adbc_current_db_schema == "public" + postgres.adbc_current_db_schema = "dbapischema" + assert postgres.adbc_current_db_schema == "dbapischema" + + +def test_conn_get_info(postgres: dbapi.Connection) -> None: + info = postgres.adbc_get_info() + assert info["driver_name"] == "ADBC PostgreSQL Driver" + assert info["driver_adbc_version"] == 1_001_000 + assert info["vendor_name"] == "PostgreSQL" + + def test_query_batch_size(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("DROP TABLE IF EXISTS test_batch_size") @@ -47,6 +74,12 @@ def test_query_batch_size(postgres: dbapi.Connection): cur.adbc_statement.set_options( **{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "1"} ) + assert ( + cur.adbc_statement.get_option_int( + StatementOptions.BATCH_SIZE_HINT_BYTES.value + ) + == 1 + ) cur.execute("SELECT * FROM test_batch_size") table = cur.fetch_arrow_table() assert len(table.to_batches()) == 65536 @@ -54,17 +87,128 @@ def test_query_batch_size(postgres: dbapi.Connection): cur.adbc_statement.set_options( **{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "4096"} ) + assert ( + cur.adbc_statement.get_option_int( + StatementOptions.BATCH_SIZE_HINT_BYTES.value + ) + == 4096 + ) cur.execute("SELECT * FROM test_batch_size") table = cur.fetch_arrow_table() assert 64 <= len(table.to_batches()) <= 256 +def test_query_cancel(postgres: dbapi.Connection) -> None: + with postgres.cursor() as cur: + cur.execute("DROP TABLE IF EXISTS test_batch_size") + cur.execute("CREATE TABLE test_batch_size (ints INT)") + cur.execute( + """ + INSERT INTO test_batch_size (ints) + SELECT generated :: INT + FROM GENERATE_SERIES(1, 1048576) temp(generated) + """ + ) + postgres.commit() + + # Ensure different ways of reading all raise the desired error + with postgres.cursor() as cur: + cur.execute("SELECT * FROM test_batch_size") + cur.adbc_cancel() + with pytest.raises(postgres.OperationalError, match="canceling statement"): + cur.fetchone() + + with postgres.cursor() as cur: + cur.execute("SELECT * FROM test_batch_size") + cur.adbc_cancel() + with pytest.raises(postgres.OperationalError, match="canceling statement"): + cur.fetch_arrow_table() + + with postgres.cursor() as cur: + cur.execute("SELECT * FROM test_batch_size") + cur.adbc_cancel() + with pytest.raises(postgres.OperationalError, match="canceling statement"): + cur.fetch_df() + + +def test_query_execute_schema(postgres: dbapi.Connection) -> None: + with postgres.cursor() as cur: + schema = cur.adbc_execute_schema("SELECT 1 AS foo") + assert schema == pyarrow.schema([("foo", "int32")]) + + +def test_query_invalid(postgres: dbapi.Connection) -> None: + with postgres.cursor() as cur: + with pytest.raises( + postgres.ProgrammingError, match="failed to prepare query" + ) as excinfo: + cur.execute("SELECT * FROM tabledoesnotexist") + + assert excinfo.value.sqlstate == "42P01" + assert len(excinfo.value.details) > 0 + + def test_query_trivial(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("SELECT 1") assert cur.fetchone() == (1,) +def test_stmt_ingest(postgres: dbapi.Connection) -> None: + table = pyarrow.table( + [ + [1, 2, 3], + ["a", None, "b"], + ], + names=["ints", "strs"], + ) + double_table = pyarrow.table( + [ + [1, 1, 2, 2, 3, 3], + ["a", "a", None, None, "b", "b"], + ], + names=["ints", "strs"], + ) + + with postgres.cursor() as cur: + cur.execute("DROP TABLE IF EXISTS test_ingest") + + with pytest.raises( + postgres.ProgrammingError, match='"test_ingest" does not exist' + ): + cur.adbc_ingest("test_ingest", table, mode="append") + postgres.rollback() + + cur.adbc_ingest("test_ingest", table, mode="replace") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == table + + with pytest.raises( + postgres.ProgrammingError, match='"test_ingest" already exists' + ): + cur.adbc_ingest("test_ingest", table, mode="create") + + cur.adbc_ingest("test_ingest", table, mode="create_append") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == double_table + + cur.adbc_ingest("test_ingest", table, mode="replace") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == table + + cur.execute("DROP TABLE IF EXISTS test_ingest") + + cur.adbc_ingest("test_ingest", table, mode="create_append") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == table + + cur.execute("DROP TABLE IF EXISTS test_ingest") + + cur.adbc_ingest("test_ingest", table, mode="create") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == table + + def test_ddl(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("DROP TABLE IF EXISTS test_ddl") diff --git a/python/adbc_driver_postgresql/tests/test_lowlevel.py b/python/adbc_driver_postgresql/tests/test_lowlevel.py index d19022bc1c..1e79713b38 100644 --- a/python/adbc_driver_postgresql/tests/test_lowlevel.py +++ b/python/adbc_driver_postgresql/tests/test_lowlevel.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import collections.abc + import pyarrow import pytest @@ -23,13 +25,15 @@ @pytest.fixture -def postgres(postgres_uri: str) -> adbc_driver_manager.AdbcConnection: +def postgres( + postgres_uri: str, +) -> collections.abc.Generator[adbc_driver_manager.AdbcConnection, None, None]: with adbc_driver_postgresql.connect(postgres_uri) as db: with adbc_driver_manager.AdbcConnection(db) as conn: yield conn -def test_query_trivial(postgres): +def test_query_trivial(postgres: adbc_driver_manager.AdbcConnection) -> None: with adbc_driver_manager.AdbcStatement(postgres) as stmt: stmt.set_sql_query("SELECT 1") stream, _ = stmt.execute_query() @@ -37,11 +41,11 @@ def test_query_trivial(postgres): assert reader.read_all() -def test_version(): +def test_version() -> None: assert adbc_driver_postgresql.__version__ # type:ignore -def test_failed_connection(): +def test_failed_connection() -> None: with pytest.raises( adbc_driver_manager.OperationalError, match=".*libpq.*Failed to connect.*" ): diff --git a/r/adbcdrivermanager/src/driver_log.c b/r/adbcdrivermanager/src/driver_log.c index 543ae9a7a9..2565aad814 100644 --- a/r/adbcdrivermanager/src/driver_log.c +++ b/r/adbcdrivermanager/src/driver_log.c @@ -111,7 +111,8 @@ static AdbcStatusCode LogConnectionCommit(struct AdbcConnection* connection, } static AdbcStatusCode LogConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { Rprintf("LogConnectionGetInfo()\n"); diff --git a/r/adbcdrivermanager/src/driver_monkey.c b/r/adbcdrivermanager/src/driver_monkey.c index 9076975741..1eb88a51f4 100644 --- a/r/adbcdrivermanager/src/driver_monkey.c +++ b/r/adbcdrivermanager/src/driver_monkey.c @@ -105,7 +105,7 @@ static AdbcStatusCode MonkeyConnectionCommit(struct AdbcConnection* connection, } static AdbcStatusCode MonkeyConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { diff --git a/r/adbcdrivermanager/src/driver_void.c b/r/adbcdrivermanager/src/driver_void.c index 59cfe03607..a17ae8a330 100644 --- a/r/adbcdrivermanager/src/driver_void.c +++ b/r/adbcdrivermanager/src/driver_void.c @@ -105,7 +105,7 @@ static AdbcStatusCode VoidConnectionCommit(struct AdbcConnection* connection, } static AdbcStatusCode VoidConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { diff --git a/r/adbcflightsql/tools/download-go.R b/r/adbcflightsql/tools/download-go.R index 3a9c40c734..262d153aaa 100644 --- a/r/adbcflightsql/tools/download-go.R +++ b/r/adbcflightsql/tools/download-go.R @@ -17,7 +17,7 @@ tmp_dir <- "src/go/tmp" -go_version <- Sys.getenv("R_ADBC_GO_VERSION_DOWNLOAD", "1.19.9") +go_version <- Sys.getenv("R_ADBC_GO_VERSION_DOWNLOAD", "1.18.10") go_platform <- tolower(Sys.info()[["sysname"]]) if (!(go_platform %in% c("darwin", "linux", "windows"))) { diff --git a/r/adbcpostgresql/bootstrap.R b/r/adbcpostgresql/bootstrap.R index a68d888fa0..2670f7605f 100644 --- a/r/adbcpostgresql/bootstrap.R +++ b/r/adbcpostgresql/bootstrap.R @@ -26,6 +26,8 @@ files_to_vendor <- c( "../../c/driver/postgresql/statement.cc", "../../c/driver/postgresql/connection.h", "../../c/driver/postgresql/connection.cc", + "../../c/driver/postgresql/error.h", + "../../c/driver/postgresql/error.cc", "../../c/driver/postgresql/database.h", "../../c/driver/postgresql/database.cc", "../../c/driver/postgresql/postgresql.cc", diff --git a/r/adbcpostgresql/src/.gitignore b/r/adbcpostgresql/src/.gitignore index 34dd9749f4..cd8318e2e9 100644 --- a/r/adbcpostgresql/src/.gitignore +++ b/r/adbcpostgresql/src/.gitignore @@ -23,6 +23,8 @@ connection.cc connection.h database.h database.cc +error.h +error.cc postgresql.cc statement.h statement.cc diff --git a/r/adbcpostgresql/src/Makevars.in b/r/adbcpostgresql/src/Makevars.in index f90db0cbb7..8244768b95 100644 --- a/r/adbcpostgresql/src/Makevars.in +++ b/r/adbcpostgresql/src/Makevars.in @@ -19,6 +19,7 @@ PKG_CPPFLAGS=-I../src @cppflags@ -DADBC_EXPORT="" PKG_LIBS=@libs@ OBJECTS = init.o \ + error.o \ connection.o \ database.o \ statement.o \ diff --git a/r/adbcpostgresql/src/Makevars.ucrt b/r/adbcpostgresql/src/Makevars.ucrt index ebdc16fbd5..a72e984560 100644 --- a/r/adbcpostgresql/src/Makevars.ucrt +++ b/r/adbcpostgresql/src/Makevars.ucrt @@ -19,6 +19,7 @@ CRT=-ucrt include Makevars.win OBJECTS = init.o \ + error.o \ connection.o \ database.o \ statement.o \ diff --git a/r/adbcpostgresql/src/Makevars.win b/r/adbcpostgresql/src/Makevars.win index 0f03930540..331ef27424 100644 --- a/r/adbcpostgresql/src/Makevars.win +++ b/r/adbcpostgresql/src/Makevars.win @@ -22,6 +22,7 @@ PKG_LIBS = -L$(RWINLIB)/lib${R_ARCH}${CRT} \ -lpq -lpgport -lpgcommon -lssl -lcrypto -lwsock32 -lsecur32 -lws2_32 -lgdi32 -lcrypt32 -lwldap32 OBJECTS = init.o \ + error.o \ connection.o \ database.o \ statement.o \ diff --git a/r/adbcsnowflake/configure b/r/adbcsnowflake/configure index 2f3ae66426..095d0cf4e3 100755 --- a/r/adbcsnowflake/configure +++ b/r/adbcsnowflake/configure @@ -73,7 +73,7 @@ fi # On OSX we need -framework Security because of some dependency somewhere if [ `uname` = "Darwin" ]; then - PKG_LIBS="-framework Security $PKG_LIBS" + PKG_LIBS="-framework Security -lresolv $PKG_LIBS" fi PKG_LIBS="$PKG_LIBS $SYMBOL_ARGS" diff --git a/r/adbcsnowflake/tools/download-go.R b/r/adbcsnowflake/tools/download-go.R index 3a9c40c734..262d153aaa 100644 --- a/r/adbcsnowflake/tools/download-go.R +++ b/r/adbcsnowflake/tools/download-go.R @@ -17,7 +17,7 @@ tmp_dir <- "src/go/tmp" -go_version <- Sys.getenv("R_ADBC_GO_VERSION_DOWNLOAD", "1.19.9") +go_version <- Sys.getenv("R_ADBC_GO_VERSION_DOWNLOAD", "1.18.10") go_platform <- tolower(Sys.info()[["sysname"]]) if (!(go_platform %in% c("darwin", "linux", "windows"))) { From 055d58fcf8866c5d63d3fe04827e2d1e28787f26 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 29 Aug 2023 08:51:53 -0400 Subject: [PATCH 09/20] fix(go/adbc): don't include NUL in error messages (#998) Fixes #997. --- .github/workflows/integration.yml | 4 +- .github/workflows/java.yml | 2 +- .github/workflows/native-unix.yml | 4 +- c/driver/postgresql/README.md | 4 +- c/driver_manager/adbc_driver_manager.cc | 8 +- ci/docker/flightsql-test.dockerfile | 20 +++ docker-compose.yml | 36 ++-- go/adbc/adbc.go | 7 +- .../driver/flightsql/cmd/testserver/main.go | 161 ++++++++++++++++++ go/adbc/drivermgr/adbc_driver_manager.cc | 8 +- go/adbc/pkg/_tmpl/driver.go.tmpl | 2 +- go/adbc/pkg/_tmpl/utils.h.tmpl | 5 +- go/adbc/pkg/flightsql/driver.go | 2 +- go/adbc/pkg/flightsql/utils.h | 7 +- go/adbc/pkg/panicdummy/driver.go | 2 +- go/adbc/pkg/panicdummy/utils.h | 5 +- go/adbc/pkg/snowflake/driver.go | 2 +- go/adbc/pkg/snowflake/utils.h | 7 +- .../adbc_driver_flightsql/tests/conftest.py | 10 ++ .../adbc_driver_flightsql/tests/test_dbapi.py | 15 ++ 20 files changed, 282 insertions(+), 29 deletions(-) create mode 100644 ci/docker/flightsql-test.dockerfile create mode 100644 go/adbc/driver/flightsql/cmd/testserver/main.go diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index c293e0f0d6..c8cf5cdd45 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -137,7 +137,7 @@ jobs: - name: Start SQLite server and Dremio shell: bash -l {0} run: | - docker-compose up -d golang-sqlite-flightsql dremio dremio-init + docker-compose up -d flightsql-test flightsql-sqlite-test dremio dremio-init - name: Build FlightSQL Driver shell: bash -l {0} @@ -155,6 +155,7 @@ jobs: ADBC_DREMIO_FLIGHTSQL_USER: "dremio" ADBC_DREMIO_FLIGHTSQL_PASS: "dremio123" ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080" + ADBC_TEST_FLIGHTSQL_URI: "grpc+tcp://localhost:41414" run: | ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" ./ci/scripts/cpp_test.sh "$(pwd)/build" @@ -174,6 +175,7 @@ jobs: ADBC_DREMIO_FLIGHTSQL_URI: "grpc+tcp://localhost:32010" ADBC_DREMIO_FLIGHTSQL_USER: "dremio" ADBC_DREMIO_FLIGHTSQL_PASS: "dremio123" + ADBC_TEST_FLIGHTSQL_URI: "grpc+tcp://localhost:41414" run: | ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" - name: Stop SQLite server and Dremio diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 10b98e1c71..d955ab2d88 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -69,7 +69,7 @@ jobs: - name: Start SQLite server shell: bash -l {0} run: | - docker-compose up -d golang-sqlite-flightsql + docker-compose up -d flightsql-sqlite-test - name: Build/Test env: ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080" diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index b12aa46d51..e3a1bac3d1 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -658,7 +658,7 @@ jobs: if: matrix.config.pkg == 'adbcpostgresql' && runner.os == 'Linux' run: | cd r/adbcpostgresql - docker compose up --detach postgres_test + docker compose up --detach postgres-test ADBC_POSTGRESQL_TEST_URI="postgresql://localhost:5432/postgres?user=postgres&password=password" echo "ADBC_POSTGRESQL_TEST_URI=${ADBC_POSTGRESQL_TEST_URI}" >> $GITHUB_ENV @@ -666,7 +666,7 @@ jobs: if: matrix.config.pkg == 'adbcflightsql' && runner.os == 'Linux' run: | cd r/adbcpostgresql - docker compose up --detach golang-sqlite-flightsql + docker compose up --detach flightsql-sqlite-test ADBC_FLIGHTSQL_TEST_URI="grpc://localhost:8080" echo "ADBC_FLIGHTSQL_TEST_URI=${ADBC_FLIGHTSQL_TEST_URI}" >> $GITHUB_ENV diff --git a/c/driver/postgresql/README.md b/c/driver/postgresql/README.md index cc5a3dfe03..8ccffb6845 100644 --- a/c/driver/postgresql/README.md +++ b/c/driver/postgresql/README.md @@ -54,9 +54,9 @@ Alternatively use the `docker compose` provided by ADBC to manage the test database container. ```shell -$ docker compose up postgres_test +$ docker compose up postgres-test # When finished: -# docker compose down postgres_test +# docker compose down postgres-test ``` Then, to run the tests, set the environment variable specifying the diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index 516bf9bbf7..e4287534df 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -642,8 +642,12 @@ const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream return nullptr; } auto* private_data = reinterpret_cast(stream->private_data); - return private_data->private_driver->ErrorFromArrayStream(&private_data->stream, - status); + auto* error = + private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); + if (error) { + const_cast(error)->private_driver = private_data->private_driver; + } + return error; } #define INIT_ERROR(ERROR, SOURCE) \ diff --git a/ci/docker/flightsql-test.dockerfile b/ci/docker/flightsql-test.dockerfile new file mode 100644 index 0000000000..7c67b06533 --- /dev/null +++ b/ci/docker/flightsql-test.dockerfile @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +ARG GO +FROM golang:${GO} +EXPOSE 41414 diff --git a/docker-compose.yml b/docker-compose.yml index dd0ef2f53f..2c77d72198 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -107,15 +107,6 @@ services: ###################### Test database environments ############################ - postgres_test: - container_name: adbc_postgres_test - image: postgres:latest - environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password - ports: - - "5432:5432" - dremio: container_name: adbc-dremio image: dremio/dremio-oss:latest @@ -150,7 +141,23 @@ services: volumes: - "./ci/scripts/integration/dremio:/init" - golang-sqlite-flightsql: + flightsql-test: + image: ${REPO}:adbc-flightsql-test + build: + context: . + cache_from: + - ${REPO}:adbc-flightsql-test + dockerfile: ci/docker/flightsql-test.dockerfile + args: + GO: ${GO} + ports: + - "41414:41414" + volumes: + - .:/adbc:delegated + command: >- + /bin/bash -c "cd /adbc/go/adbc && go run ./driver/flightsql/cmd/testserver -host 0.0.0.0 -port 41414" + + flightsql-sqlite-test: image: ${REPO}:golang-${GO}-sqlite-flightsql build: context: . @@ -162,3 +169,12 @@ services: ARROW_MAJOR_VERSION: ${ARROW_MAJOR_VERSION} ports: - 8080:8080 + + postgres-test: + container_name: adbc_postgres_test + image: postgres:latest + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + ports: + - "5432:5432" diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index ad6194f240..b0737fe02a 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -131,7 +131,12 @@ type Error struct { } func (e Error) Error() string { - return fmt.Sprintf("%s: SqlState: %s, msg: %s", e.Code, string(e.SqlState[:]), e.Msg) + // Don't include a NUL in the string since C Data Interface uses char* (and + // don't include the extra cruft if not needed in the first place) + if e.SqlState[0] != 0 { + return fmt.Sprintf("%s: %s (%s)", e.Code, e.Msg, string(e.SqlState[:])) + } + return fmt.Sprintf("%s: %s", e.Code, e.Msg) } // Status represents an error code for operations that may fail diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go new file mode 100644 index 0000000000..6e0ca4ffa8 --- /dev/null +++ b/go/adbc/driver/flightsql/cmd/testserver/main.go @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A server intended specifically for testing the Flight SQL driver. Unlike +// the upstream SQLite example, which tries to be functional, this server +// tries to be useful. + +package main + +import ( + "bytes" + "context" + "flag" + "fmt" + "log" + "net" + "os" + "strconv" + "strings" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/flight" + "github.com/apache/arrow/go/v13/arrow/flight/flightsql" + "github.com/apache/arrow/go/v13/arrow/memory" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type ExampleServer struct { + flightsql.BaseServer +} + +func (srv *ExampleServer) ClosePreparedStatement(ctx context.Context, request flightsql.ActionClosePreparedStatementRequest) error { + return nil +} + +func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) { + result.Handle = []byte(req.GetQuery()) + return +} + +func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get")) || bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get_stream")) { + schema := arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + Schema: flight.SerializeSchema(schema, srv.Alloc), + }, nil + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + ticket, err := flightsql.CreateStatementQueryTicket(desc.Cmd) + if err != nil { + return nil, err + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { + log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle()) + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get")) { + err = status.Error(codes.InvalidArgument, "expected error") + return + } + + schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": 5}]`)) + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get_stream")) { + ch <- flight.StreamChunk{ + Data: nil, + Desc: nil, + Err: status.Error(codes.InvalidArgument, "expected error"), + } + } + }() + out = ch + return +} + +func (srv *ExampleServer) DoGetStatement(ctx context.Context, cmd flightsql.StatementQueryTicket) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { + schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"ints": 5}]`)) + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + }() + out = ch + return +} + +func main() { + var ( + host = flag.String("host", "localhost", "hostname to bind to") + port = flag.Int("port", 0, "port to bind to") + ) + + flag.Parse() + + srv := &ExampleServer{} + srv.Alloc = memory.DefaultAllocator + + server := flight.NewServerWithMiddleware(nil) + server.RegisterFlightService(flightsql.NewFlightServer(srv)) + if err := server.Init(net.JoinHostPort(*host, strconv.Itoa(*port))); err != nil { + log.Fatal(err) + } + server.SetShutdownOnSignals(os.Interrupt, os.Kill) + + fmt.Println("Starting testing Flight SQL Server on", server.Addr(), "...") + + if err := server.Serve(); err != nil { + log.Fatal(err) + } +} diff --git a/go/adbc/drivermgr/adbc_driver_manager.cc b/go/adbc/drivermgr/adbc_driver_manager.cc index 516bf9bbf7..e4287534df 100644 --- a/go/adbc/drivermgr/adbc_driver_manager.cc +++ b/go/adbc/drivermgr/adbc_driver_manager.cc @@ -642,8 +642,12 @@ const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream return nullptr; } auto* private_data = reinterpret_cast(stream->private_data); - return private_data->private_driver->ErrorFromArrayStream(&private_data->stream, - status); + auto* error = + private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); + if (error) { + const_cast(error)->private_driver = private_data->private_driver; + } + return error; } #define INIT_ERROR(ERROR, SOURCE) \ diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl index fc489a4016..24c15f3960 100644 --- a/go/adbc/pkg/_tmpl/driver.go.tmpl +++ b/go/adbc/pkg/_tmpl/driver.go.tmpl @@ -263,7 +263,7 @@ func (cStream *cArrayStream) maybeError() C.int { if cStream.adbcErr != nil { C.{{.Prefix}}errRelease(cStream.adbcErr) } else { - cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) } cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) diff --git a/go/adbc/pkg/_tmpl/utils.h.tmpl b/go/adbc/pkg/_tmpl/utils.h.tmpl index ce3dba9dbc..d73f4bad71 100644 --- a/go/adbc/pkg/_tmpl/utils.h.tmpl +++ b/go/adbc/pkg/_tmpl/utils.h.tmpl @@ -83,7 +83,10 @@ AdbcStatusCode {{.Prefix}}StatementSetSubstraitPlan(struct AdbcStatement* stmt, AdbcStatusCode {{.Prefix}}DriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void {{.Prefix}}errRelease(struct AdbcError* error) { - error->release(error); + if (error->release) { + error->release(error); + error->release = NULL; + } } void {{.Prefix}}_release_error(struct AdbcError* error); diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go index 925fd8658e..46e096952c 100644 --- a/go/adbc/pkg/flightsql/driver.go +++ b/go/adbc/pkg/flightsql/driver.go @@ -267,7 +267,7 @@ func (cStream *cArrayStream) maybeError() C.int { if cStream.adbcErr != nil { C.FlightSQLerrRelease(cStream.adbcErr) } else { - cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) } cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) diff --git a/go/adbc/pkg/flightsql/utils.h b/go/adbc/pkg/flightsql/utils.h index e3b22fb737..fbdbe89a8a 100644 --- a/go/adbc/pkg/flightsql/utils.h +++ b/go/adbc/pkg/flightsql/utils.h @@ -153,7 +153,12 @@ AdbcStatusCode FlightSQLStatementSetSubstraitPlan(struct AdbcStatement* stmt, AdbcStatusCode FlightSQLDriverInit(int version, void* rawDriver, struct AdbcError* err); -static inline void FlightSQLerrRelease(struct AdbcError* error) { error->release(error); } +static inline void FlightSQLerrRelease(struct AdbcError* error) { + if (error->release) { + error->release(error); + error->release = NULL; + } +} void FlightSQL_release_error(struct AdbcError* error); diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go index d1c143a762..c99153ccb5 100644 --- a/go/adbc/pkg/panicdummy/driver.go +++ b/go/adbc/pkg/panicdummy/driver.go @@ -267,7 +267,7 @@ func (cStream *cArrayStream) maybeError() C.int { if cStream.adbcErr != nil { C.PanicDummyerrRelease(cStream.adbcErr) } else { - cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) } cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) diff --git a/go/adbc/pkg/panicdummy/utils.h b/go/adbc/pkg/panicdummy/utils.h index 91d8294c4e..b8db59c227 100644 --- a/go/adbc/pkg/panicdummy/utils.h +++ b/go/adbc/pkg/panicdummy/utils.h @@ -156,7 +156,10 @@ AdbcStatusCode PanicDummyStatementSetSubstraitPlan(struct AdbcStatement* stmt, AdbcStatusCode PanicDummyDriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void PanicDummyerrRelease(struct AdbcError* error) { - error->release(error); + if (error->release) { + error->release(error); + error->release = NULL; + } } void PanicDummy_release_error(struct AdbcError* error); diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go index 6ca09646d4..4804e32e38 100644 --- a/go/adbc/pkg/snowflake/driver.go +++ b/go/adbc/pkg/snowflake/driver.go @@ -267,7 +267,7 @@ func (cStream *cArrayStream) maybeError() C.int { if cStream.adbcErr != nil { C.SnowflakeerrRelease(cStream.adbcErr) } else { - cStream.adbcErr = (*C.struct_AdbcError)(C.malloc(C.ADBC_ERROR_1_1_0_SIZE)) + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) } cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) diff --git a/go/adbc/pkg/snowflake/utils.h b/go/adbc/pkg/snowflake/utils.h index 23391dfd70..c679316232 100644 --- a/go/adbc/pkg/snowflake/utils.h +++ b/go/adbc/pkg/snowflake/utils.h @@ -153,7 +153,12 @@ AdbcStatusCode SnowflakeStatementSetSubstraitPlan(struct AdbcStatement* stmt, AdbcStatusCode SnowflakeDriverInit(int version, void* rawDriver, struct AdbcError* err); -static inline void SnowflakeerrRelease(struct AdbcError* error) { error->release(error); } +static inline void SnowflakeerrRelease(struct AdbcError* error) { + if (error->release) { + error->release(error); + error->release = NULL; + } +} void Snowflake_release_error(struct AdbcError* error); diff --git a/python/adbc_driver_flightsql/tests/conftest.py b/python/adbc_driver_flightsql/tests/conftest.py index 4ca9508d07..b4eb181105 100644 --- a/python/adbc_driver_flightsql/tests/conftest.py +++ b/python/adbc_driver_flightsql/tests/conftest.py @@ -71,3 +71,13 @@ def dremio_dbapi(dremio_uri, dremio_user, dremio_pass): }, ) as conn: yield conn + + +@pytest.fixture +def test_dbapi(): + uri = os.environ.get("ADBC_TEST_FLIGHTSQL_URI") + if not uri: + pytest.skip("Set ADBC_TEST_FLIGHTSQL_URI to run tests") + + with adbc_driver_flightsql.dbapi.connect(uri) as conn: + yield conn diff --git a/python/adbc_driver_flightsql/tests/test_dbapi.py b/python/adbc_driver_flightsql/tests/test_dbapi.py index 0918fc7a93..e199035473 100644 --- a/python/adbc_driver_flightsql/tests/test_dbapi.py +++ b/python/adbc_driver_flightsql/tests/test_dbapi.py @@ -33,6 +33,21 @@ def test_query_error(dremio_dbapi): assert exc.args[0].startswith("INVALID_ARGUMENT: [FlightSQL] ") +def test_query_error_fetch(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("error_do_get") + with pytest.raises(Exception, match="expected error"): + cur.fetch_arrow_table() + + +def test_query_error_stream(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("error_do_get_stream") + with pytest.raises(Exception, match="expected error"): + cur.fetchone() + cur.fetchone() + + def test_query_trivial(dremio_dbapi): with dremio_dbapi.cursor() as cur: cur.execute("SELECT 1") From 6285a8aa926f2c463588aef2b032b89ae08390d9 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 29 Aug 2023 14:52:55 -0400 Subject: [PATCH 10/20] build(r/adbcsnowflake): add -lresolv (#1006) Required on Go 1.20+ since it no longer links this by default. Fixes #995. --- r/adbcflightsql/src/Makevars.in | 2 +- r/adbcsnowflake/configure | 2 +- r/adbcsnowflake/src/Makevars.in | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/r/adbcflightsql/src/Makevars.in b/r/adbcflightsql/src/Makevars.in index 07423e73ed..eb74ef3356 100644 --- a/r/adbcflightsql/src/Makevars.in +++ b/r/adbcflightsql/src/Makevars.in @@ -27,4 +27,4 @@ all: $(SHLIB) $(SHLIB): gostatic gostatic: - (cd "$(CURDIR)/go/adbc"; CC="$(CGO_CC)" CXX="$(CGO_CXX)" CGO_CFLAGS="$(CGO_CFLAGS)" "@gobin@" build -v -tags driverlib -o $(CURDIR)/go/libadbc_driver_flightsql.a -buildmode=c-archive "./pkg/flightsql") + (cd "$(CURDIR)/go/adbc"; CC="$(CGO_CC)" CXX="$(CGO_CXX)" CGO_CFLAGS="$(CGO_CFLAGS)" CGO_LDFLAGS="$(PKG_LIBS)" "@gobin@" build -v -tags driverlib -o $(CURDIR)/go/libadbc_driver_flightsql.a -buildmode=c-archive "./pkg/flightsql") diff --git a/r/adbcsnowflake/configure b/r/adbcsnowflake/configure index 095d0cf4e3..2f3ae66426 100755 --- a/r/adbcsnowflake/configure +++ b/r/adbcsnowflake/configure @@ -73,7 +73,7 @@ fi # On OSX we need -framework Security because of some dependency somewhere if [ `uname` = "Darwin" ]; then - PKG_LIBS="-framework Security -lresolv $PKG_LIBS" + PKG_LIBS="-framework Security $PKG_LIBS" fi PKG_LIBS="$PKG_LIBS $SYMBOL_ARGS" diff --git a/r/adbcsnowflake/src/Makevars.in b/r/adbcsnowflake/src/Makevars.in index f698373532..aa732f5892 100644 --- a/r/adbcsnowflake/src/Makevars.in +++ b/r/adbcsnowflake/src/Makevars.in @@ -16,7 +16,7 @@ # under the License. PKG_CPPFLAGS=-I$(CURDIR)/src -DADBC_EXPORT="" -PKG_LIBS=-L$(CURDIR)/go -ladbc_driver_snowflake @libs@ +PKG_LIBS=-L$(CURDIR)/go -ladbc_driver_snowflake -lresolv @libs@ CGO_CC = @cc@ CGO_CXX = @cxx@ @@ -27,4 +27,4 @@ all: $(SHLIB) $(SHLIB): gostatic gostatic: - (cd "$(CURDIR)/go/adbc"; CC="$(CGO_CC)" CXX="$(CGO_CXX)" CGO_CFLAGS="$(CGO_CFLAGS)" "@gobin@" build -v -tags driverlib -o $(CURDIR)/go/libadbc_driver_snowflake.a -buildmode=c-archive "./pkg/snowflake") + (cd "$(CURDIR)/go/adbc"; CC="$(CGO_CC)" CXX="$(CGO_CXX)" CGO_CFLAGS="$(CGO_CFLAGS)" CGO_LDFLAGS="$(PKG_LIBS)" "@gobin@" build -v -tags driverlib -o $(CURDIR)/go/libadbc_driver_snowflake.a -buildmode=c-archive "./pkg/snowflake") From a4bf0bfa6e2ba75cd7580acf2798c122ff2a35b1 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 30 Aug 2023 11:42:01 -0300 Subject: [PATCH 11/20] chore(r): Package-level CRAN preparation chores (#1013) For #1009, #1010, #1011, and #1012. Ideally I would have done these chores before 0.6.0; however, because three of these are first-time CRAN submissions that require a human, there will probably be a few more PRs that need to be picked into that branch. --- dev/release/rat_exclude_files.txt | 1 + r/adbcdrivermanager/NEWS.md | 7 +++++++ r/adbcdrivermanager/cran-comments.md | 7 ++++--- r/adbcflightsql/.Rbuildignore | 1 + r/adbcflightsql/NEWS.md | 3 +++ r/adbcflightsql/README.Rmd | 6 ++++++ r/adbcflightsql/README.md | 7 +++++++ r/adbcflightsql/cran-comments.md | 9 +++++++++ r/adbcpostgresql/.Rbuildignore | 1 + r/adbcpostgresql/NEWS.md | 3 +++ r/adbcpostgresql/README.Rmd | 6 ++++++ r/adbcpostgresql/README.md | 7 +++++++ r/adbcpostgresql/cran-comments.md | 9 +++++++++ r/adbcsqlite/.Rbuildignore | 1 + r/adbcsqlite/NEWS.md | 3 +++ r/adbcsqlite/README.Rmd | 6 ++++++ r/adbcsqlite/README.md | 7 +++++++ r/adbcsqlite/cran-comments.md | 9 +++++++++ 18 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 r/adbcdrivermanager/NEWS.md create mode 100644 r/adbcflightsql/NEWS.md create mode 100644 r/adbcflightsql/cran-comments.md create mode 100644 r/adbcpostgresql/NEWS.md create mode 100644 r/adbcpostgresql/cran-comments.md create mode 100644 r/adbcsqlite/NEWS.md create mode 100644 r/adbcsqlite/cran-comments.md diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 644ab50702..9e95dcfcff 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -25,6 +25,7 @@ python/*/*/py.typed rat.txt r/*/DESCRIPTION r/*/NAMESPACE +r/*/NEWS.md r/*/cran-comments.md r/*/.Rbuildignore r/*/tests/testthat/_snaps/* diff --git a/r/adbcdrivermanager/NEWS.md b/r/adbcdrivermanager/NEWS.md new file mode 100644 index 0000000000..0f69dca35f --- /dev/null +++ b/r/adbcdrivermanager/NEWS.md @@ -0,0 +1,7 @@ +# adbcdrivermanager 0.6.0 + +- **r**: Ensure that info_codes are coerced to integer (#986) + +# adbcdrivermanager 0.5.0 + +* Added a `NEWS.md` file to track changes to the package. diff --git a/r/adbcdrivermanager/cran-comments.md b/r/adbcdrivermanager/cran-comments.md index e3b0ab85ab..906efef731 100644 --- a/r/adbcdrivermanager/cran-comments.md +++ b/r/adbcdrivermanager/cran-comments.md @@ -1,6 +1,7 @@ -## R CMD check results +An update to reflect the updated upstream sources for the latest +Apache Arrow ADBC libraries version. -0 errors | 0 warnings | 1 note +## R CMD check results -* This is a new release. +0 errors | 0 warnings | 0 notes diff --git a/r/adbcflightsql/.Rbuildignore b/r/adbcflightsql/.Rbuildignore index 389c04f592..1d0408c555 100644 --- a/r/adbcflightsql/.Rbuildignore +++ b/r/adbcflightsql/.Rbuildignore @@ -10,3 +10,4 @@ ^_pkgdown\.yml$ ^docs$ ^pkgdown$ +^cran-comments\.md$ diff --git a/r/adbcflightsql/NEWS.md b/r/adbcflightsql/NEWS.md new file mode 100644 index 0000000000..a2980eac97 --- /dev/null +++ b/r/adbcflightsql/NEWS.md @@ -0,0 +1,3 @@ +# adbcflightsql 0.6.0 + +* Initial CRAN submission. diff --git a/r/adbcflightsql/README.Rmd b/r/adbcflightsql/README.Rmd index 759e89926b..97f0e2a5e8 100644 --- a/r/adbcflightsql/README.Rmd +++ b/r/adbcflightsql/README.Rmd @@ -40,6 +40,12 @@ to the Arrow Database Connectivity (ADBC) FlightSQL driver. ## Installation +You can install the released version of adbcflightsql from [CRAN](https://cran.r-project.org/) with: + +```r +install.packages("adbcflightsql") +``` + You can install the development version of adbcflightsql from [GitHub](https://github.com/) with: ``` r diff --git a/r/adbcflightsql/README.md b/r/adbcflightsql/README.md index e61bfc02c9..352529424f 100644 --- a/r/adbcflightsql/README.md +++ b/r/adbcflightsql/README.md @@ -27,6 +27,13 @@ interface to the Arrow Database Connectivity (ADBC) FlightSQL driver. ## Installation +You can install the released version of adbcflightsql from +[CRAN](https://cran.r-project.org/) with: + +``` r +install.packages("adbcflightsql") +``` + You can install the development version of adbcflightsql from [GitHub](https://github.com/) with: diff --git a/r/adbcflightsql/cran-comments.md b/r/adbcflightsql/cran-comments.md new file mode 100644 index 0000000000..2d2bbbdc12 --- /dev/null +++ b/r/adbcflightsql/cran-comments.md @@ -0,0 +1,9 @@ + +An update to reflect the updated upstream sources for the latest +Apache Arrow ADBC libraries version. + +## R CMD check results + +0 errors | 0 warnings | 1 note + +* This is a new release. diff --git a/r/adbcpostgresql/.Rbuildignore b/r/adbcpostgresql/.Rbuildignore index 70ce666548..2525a88aa2 100644 --- a/r/adbcpostgresql/.Rbuildignore +++ b/r/adbcpostgresql/.Rbuildignore @@ -13,3 +13,4 @@ ^_pkgdown\.yml$ ^docs$ ^pkgdown$ +^cran-comments\.md$ diff --git a/r/adbcpostgresql/NEWS.md b/r/adbcpostgresql/NEWS.md new file mode 100644 index 0000000000..842387e43e --- /dev/null +++ b/r/adbcpostgresql/NEWS.md @@ -0,0 +1,3 @@ +# adbcpostgresql 0.6.0 + +* Initial CRAN submission. diff --git a/r/adbcpostgresql/README.Rmd b/r/adbcpostgresql/README.Rmd index d34b37bec1..cc68887b24 100644 --- a/r/adbcpostgresql/README.Rmd +++ b/r/adbcpostgresql/README.Rmd @@ -40,6 +40,12 @@ to the Arrow Database Connectivity (ADBC) PostgreSQL driver. ## Installation +You can install the released version of adbcpostgresql from [CRAN](https://cran.r-project.org/) with: + +```r +install.packages("adbcpostgresql") +``` + You can install the development version of adbcpostgresql from [GitHub](https://github.com/) with: ``` r diff --git a/r/adbcpostgresql/README.md b/r/adbcpostgresql/README.md index b91d2e3ecf..b902cc8c19 100644 --- a/r/adbcpostgresql/README.md +++ b/r/adbcpostgresql/README.md @@ -27,6 +27,13 @@ interface to the Arrow Database Connectivity (ADBC) PostgreSQL driver. ## Installation +You can install the released version of adbcpostgresql from +[CRAN](https://cran.r-project.org/) with: + +``` r +install.packages("adbcpostgresql") +``` + You can install the development version of adbcpostgresql from [GitHub](https://github.com/) with: diff --git a/r/adbcpostgresql/cran-comments.md b/r/adbcpostgresql/cran-comments.md new file mode 100644 index 0000000000..2d2bbbdc12 --- /dev/null +++ b/r/adbcpostgresql/cran-comments.md @@ -0,0 +1,9 @@ + +An update to reflect the updated upstream sources for the latest +Apache Arrow ADBC libraries version. + +## R CMD check results + +0 errors | 0 warnings | 1 note + +* This is a new release. diff --git a/r/adbcsqlite/.Rbuildignore b/r/adbcsqlite/.Rbuildignore index e96b9e144c..7a0fc58e1c 100644 --- a/r/adbcsqlite/.Rbuildignore +++ b/r/adbcsqlite/.Rbuildignore @@ -10,3 +10,4 @@ ^_pkgdown\.yml$ ^docs$ ^pkgdown$ +^cran-comments\.md$ diff --git a/r/adbcsqlite/NEWS.md b/r/adbcsqlite/NEWS.md new file mode 100644 index 0000000000..f77c65de3c --- /dev/null +++ b/r/adbcsqlite/NEWS.md @@ -0,0 +1,3 @@ +# adbcsqlite 0.6.0 + +* Initial CRAN submission. diff --git a/r/adbcsqlite/README.Rmd b/r/adbcsqlite/README.Rmd index a59c1ffc03..ffcaee358e 100644 --- a/r/adbcsqlite/README.Rmd +++ b/r/adbcsqlite/README.Rmd @@ -40,6 +40,12 @@ to the Arrow Database Connectivity (ADBC) SQLite driver. ## Installation +You can install the released version of adbcsqlite from [CRAN](https://cran.r-project.org/) with: + +```r +install.packages("adbcsqlite") +``` + You can install the development version of adbcsqlite from [GitHub](https://github.com/) with: ``` r diff --git a/r/adbcsqlite/README.md b/r/adbcsqlite/README.md index d5d8eb7342..f04f07cbee 100644 --- a/r/adbcsqlite/README.md +++ b/r/adbcsqlite/README.md @@ -27,6 +27,13 @@ interface to the Arrow Database Connectivity (ADBC) SQLite driver. ## Installation +You can install the released version of adbcsqlite from +[CRAN](https://cran.r-project.org/) with: + +``` r +install.packages("adbcsqlite") +``` + You can install the development version of adbcsqlite from [GitHub](https://github.com/) with: diff --git a/r/adbcsqlite/cran-comments.md b/r/adbcsqlite/cran-comments.md new file mode 100644 index 0000000000..2d2bbbdc12 --- /dev/null +++ b/r/adbcsqlite/cran-comments.md @@ -0,0 +1,9 @@ + +An update to reflect the updated upstream sources for the latest +Apache Arrow ADBC libraries version. + +## R CMD check results + +0 errors | 0 warnings | 1 note + +* This is a new release. From 587e25f9755d2cbbd755c944f1245b087ab97844 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 30 Aug 2023 12:48:29 -0300 Subject: [PATCH 12/20] chore(r): Ensure package documentation has `\alias` (#1014) Remotes a CMD check NOTE on R-devel: ``` * checking Rd metadata ... NOTE Rd files without \alias: 'adbcsqlite-package.Rd' ``` I had used `@aliases NULL` because each driver package has an identically named function and using `@aliases NULL` ensured that there were no duplicate aliases. Recent r-devel added a check for .Rd files with no alias, hence the need for an update. For #1010, #1011, and #1012. --- r/adbcflightsql/R/adbcflightsql-package.R | 2 +- r/adbcflightsql/man/adbcflightsql-package.Rd | 2 ++ r/adbcpostgresql/R/adbcpostgresql-package.R | 2 +- r/adbcpostgresql/man/adbcpostgresql-package.Rd | 2 ++ r/adbcsnowflake/R/adbcsnowflake-package.R | 2 +- r/adbcsnowflake/man/adbcsnowflake-package.Rd | 2 ++ r/adbcsqlite/R/adbcsqlite-package.R | 2 +- r/adbcsqlite/man/adbcsqlite-package.Rd | 2 ++ 8 files changed, 12 insertions(+), 4 deletions(-) diff --git a/r/adbcflightsql/R/adbcflightsql-package.R b/r/adbcflightsql/R/adbcflightsql-package.R index 41c9cbe3c6..20ae6f02bd 100644 --- a/r/adbcflightsql/R/adbcflightsql-package.R +++ b/r/adbcflightsql/R/adbcflightsql-package.R @@ -16,7 +16,7 @@ # under the License. #' @keywords internal -#' @aliases NULL +#' @aliases adbcflightsql-package "_PACKAGE" ## usethis namespace: start diff --git a/r/adbcflightsql/man/adbcflightsql-package.Rd b/r/adbcflightsql/man/adbcflightsql-package.Rd index aad70825c0..1b84564545 100644 --- a/r/adbcflightsql/man/adbcflightsql-package.Rd +++ b/r/adbcflightsql/man/adbcflightsql-package.Rd @@ -2,6 +2,8 @@ % Please edit documentation in R/adbcflightsql-package.R \docType{package} \name{adbcflightsql-package} +\alias{adbcflightsql-package} +\alias{_PACKAGE} \title{adbcflightsql: 'Arrow' Database Connectivity ('ADBC') 'FlightSQL' Driver} \description{ Provides a developer-facing interface to the 'Arrow' Database Connectivity ('ADBC') 'FlightSQL' driver for the purposes of building high-level database interfaces for users. 'ADBC' \url{https://arrow.apache.org/adbc/} is an API standard for database access libraries that uses 'Arrow' for result sets and query parameters. diff --git a/r/adbcpostgresql/R/adbcpostgresql-package.R b/r/adbcpostgresql/R/adbcpostgresql-package.R index b116eeb17e..19c3e6d3ac 100644 --- a/r/adbcpostgresql/R/adbcpostgresql-package.R +++ b/r/adbcpostgresql/R/adbcpostgresql-package.R @@ -16,7 +16,7 @@ # under the License. #' @keywords internal -#' @aliases NULL +#' @aliases adbcpostgresql-package "_PACKAGE" ## usethis namespace: start diff --git a/r/adbcpostgresql/man/adbcpostgresql-package.Rd b/r/adbcpostgresql/man/adbcpostgresql-package.Rd index 33ca3933d2..366a96b5dd 100644 --- a/r/adbcpostgresql/man/adbcpostgresql-package.Rd +++ b/r/adbcpostgresql/man/adbcpostgresql-package.Rd @@ -2,6 +2,8 @@ % Please edit documentation in R/adbcpostgresql-package.R \docType{package} \name{adbcpostgresql-package} +\alias{adbcpostgresql-package} +\alias{_PACKAGE} \title{adbcpostgresql: 'Arrow' Database Connectivity ('ADBC') 'PostgreSQL' Driver} \description{ Provides a developer-facing interface to the 'Arrow' Database Connectivity ('ADBC') 'PostgreSQL' driver for the purposes of building high-level database interfaces for users. 'ADBC' \url{https://arrow.apache.org/adbc/} is an API standard for database access libraries that uses 'Arrow' for result sets and query parameters. diff --git a/r/adbcsnowflake/R/adbcsnowflake-package.R b/r/adbcsnowflake/R/adbcsnowflake-package.R index 52b2839745..89247d8562 100644 --- a/r/adbcsnowflake/R/adbcsnowflake-package.R +++ b/r/adbcsnowflake/R/adbcsnowflake-package.R @@ -16,7 +16,7 @@ # under the License. #' @keywords internal -#' @aliases NULL +#' @aliases adbcsnowflake-package "_PACKAGE" ## usethis namespace: start diff --git a/r/adbcsnowflake/man/adbcsnowflake-package.Rd b/r/adbcsnowflake/man/adbcsnowflake-package.Rd index 3a0a99993f..684acc2be4 100644 --- a/r/adbcsnowflake/man/adbcsnowflake-package.Rd +++ b/r/adbcsnowflake/man/adbcsnowflake-package.Rd @@ -2,6 +2,8 @@ % Please edit documentation in R/adbcsnowflake-package.R \docType{package} \name{adbcsnowflake-package} +\alias{adbcsnowflake-package} +\alias{_PACKAGE} \title{adbcsnowflake: Arrow Database Connectivity ('ADBC') 'Snowflake' Driver} \description{ Provides a developer-facing interface to the 'Arrow' Database Connectivity ('ADBC') 'Snowflake' driver for the purposes of building high-level database interfaces for users. 'ADBC' \url{https://arrow.apache.org/adbc/} is an API standard for database access libraries that uses 'Arrow' for result sets and query parameters. diff --git a/r/adbcsqlite/R/adbcsqlite-package.R b/r/adbcsqlite/R/adbcsqlite-package.R index f9611d28a9..8de0eb6870 100644 --- a/r/adbcsqlite/R/adbcsqlite-package.R +++ b/r/adbcsqlite/R/adbcsqlite-package.R @@ -16,7 +16,7 @@ # under the License. #' @keywords internal -#' @aliases NULL +#' @aliases adbcsqlite-package "_PACKAGE" ## usethis namespace: start diff --git a/r/adbcsqlite/man/adbcsqlite-package.Rd b/r/adbcsqlite/man/adbcsqlite-package.Rd index 4d461b927f..50635ca37f 100644 --- a/r/adbcsqlite/man/adbcsqlite-package.Rd +++ b/r/adbcsqlite/man/adbcsqlite-package.Rd @@ -2,6 +2,8 @@ % Please edit documentation in R/adbcsqlite-package.R \docType{package} \name{adbcsqlite-package} +\alias{adbcsqlite-package} +\alias{_PACKAGE} \title{adbcsqlite: 'Arrow' Database Connectivity ('ADBC') 'SQLite' Driver} \description{ Provides a developer-facing interface to the 'Arrow' Database Connectivity ('ADBC') 'SQLite' driver for the purposes of building high-level database interfaces for users. 'ADBC' \url{https://arrow.apache.org/adbc/} is an API standard for database access libraries that uses 'Arrow' for result sets and query parameters. From 114a1f810929a3dad3cc850ba5e751076116a7d4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 31 Aug 2023 11:26:58 -0300 Subject: [PATCH 13/20] chore(c/vendor): Make __int128 safe-math header opt-in (#1016) When trying to submit adbcpostgres, I get: ``` * checking whether package 'adbcpostgresql' can be installed ... WARNING Found the following significant warnings: vendor/portable-snippets/safe-math.h:171:9: warning: ISO C++ does not support '__int128' for 'psnip_safe_int128_t' [-Wpedantic] vendor/portable-snippets/safe-math.h:172:18: warning: ISO C++ does not support '__int128' for 'psnip_safe_uint128_t' [-Wpedantic] See 'd:/RCompile/CRANguest/R-devel/adbcpostgresql.Rcheck/00install.out' for details. ``` This is the workaround that I'm including in the CRAN packaging branch for 0.6.0 but it would be nice to avoid dealing with it on subsequent submissions. The only function we use from this header is the overflow-safe int64 multiplication for timestamp support in the postgres driver...given that we're always multiplying by a fixed 1000 in that line, we could also probably hard-code the check + test the last supported date? --- c/vendor/portable-snippets/safe-math.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/c/vendor/portable-snippets/safe-math.h b/c/vendor/portable-snippets/safe-math.h index 7f6426ac76..797404ae4f 100644 --- a/c/vendor/portable-snippets/safe-math.h +++ b/c/vendor/portable-snippets/safe-math.h @@ -166,11 +166,15 @@ #define PSNIP_SAFE_IS_LARGER(ORIG_MAX, DEST_MAX) ((DEST_MAX / ORIG_MAX) >= ORIG_MAX) +// Using __int128 intrinsics causes compilation to fail with -Wpedantic +// which is required to pass CRAN incoming checks for R packages that use this header +#if defined(PSNIP_USE_INTRINSIC_INT128) #if defined(__GNUC__) && ((__GNUC__ >= 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__SIZEOF_INT128__) && !defined(__ibmxl__) #define PSNIP_SAFE_HAVE_128 typedef __int128 psnip_safe_int128_t; typedef unsigned __int128 psnip_safe_uint128_t; #endif /* defined(__GNUC__) */ +#endif #if !defined(PSNIP_SAFE_NO_FIXED) #define PSNIP_SAFE_HAVE_INT8_LARGER From 1bb507f300ee257aac878b44c858bf8d5da8b2b8 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 1 Sep 2023 09:11:52 -0300 Subject: [PATCH 14/20] fix(r/adbcdrivermanager): Make `adbc_xptr_is_valid()` return `FALSE` for external pointer to NULL (#1007) Closes #1001. In theory one could differentiate the NULL from the "non-null but invalid" case, but I think returning FALSE is the least confusing thing to do. --- r/adbcdrivermanager/src/radbc.cc | 6 +++--- r/adbcdrivermanager/src/radbc.h | 4 ++-- .../tests/testthat/test-utils.R | 20 +++++++++++++++++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/r/adbcdrivermanager/src/radbc.cc b/r/adbcdrivermanager/src/radbc.cc index 92f8e31ca8..224c7f81eb 100644 --- a/r/adbcdrivermanager/src/radbc.cc +++ b/r/adbcdrivermanager/src/radbc.cc @@ -150,7 +150,7 @@ extern "C" SEXP RAdbcMoveDatabase(SEXP database_xptr) { } extern "C" SEXP RAdbcDatabaseValid(SEXP database_xptr) { - AdbcDatabase* database = adbc_from_xptr(database_xptr); + AdbcDatabase* database = adbc_from_xptr(database_xptr, true); return Rf_ScalarLogical(database != nullptr && database->private_data != nullptr); } @@ -219,7 +219,7 @@ extern "C" SEXP RAdbcMoveConnection(SEXP connection_xptr) { } extern "C" SEXP RAdbcConnectionValid(SEXP connection_xptr) { - AdbcConnection* connection = adbc_from_xptr(connection_xptr); + AdbcConnection* connection = adbc_from_xptr(connection_xptr, true); return Rf_ScalarLogical(connection != nullptr && connection->private_data != nullptr); } @@ -396,7 +396,7 @@ extern "C" SEXP RAdbcMoveStatement(SEXP statement_xptr) { } extern "C" SEXP RAdbcStatementValid(SEXP statement_xptr) { - AdbcStatement* statement = adbc_from_xptr(statement_xptr); + AdbcStatement* statement = adbc_from_xptr(statement_xptr, true); return Rf_ScalarLogical(statement != nullptr && statement->private_data != nullptr); } diff --git a/r/adbcdrivermanager/src/radbc.h b/r/adbcdrivermanager/src/radbc.h index 73b37a3dce..9baf8d6d89 100644 --- a/r/adbcdrivermanager/src/radbc.h +++ b/r/adbcdrivermanager/src/radbc.h @@ -64,13 +64,13 @@ inline const char* adbc_xptr_class() { } template -static inline T* adbc_from_xptr(SEXP xptr) { +static inline T* adbc_from_xptr(SEXP xptr, bool null_ok = false) { if (!Rf_inherits(xptr, adbc_xptr_class())) { Rf_error("Expected external pointer with class '%s'", adbc_xptr_class()); } T* ptr = reinterpret_cast(R_ExternalPtrAddr(xptr)); - if (ptr == nullptr) { + if (!null_ok && ptr == nullptr) { Rf_error("Can't convert external pointer to NULL to T*"); } return ptr; diff --git a/r/adbcdrivermanager/tests/testthat/test-utils.R b/r/adbcdrivermanager/tests/testthat/test-utils.R index a533ba99e3..3c46c489a7 100644 --- a/r/adbcdrivermanager/tests/testthat/test-utils.R +++ b/r/adbcdrivermanager/tests/testthat/test-utils.R @@ -80,7 +80,27 @@ test_that("pointer mover leaves behind an invalid external pointer", { expect_true(adbc_xptr_is_valid(stream)) expect_true(adbc_xptr_is_valid(adbc_xptr_move(stream))) expect_false(adbc_xptr_is_valid(stream)) +}) + +test_that("adbc_xptr_is_valid() returns FALSE for null pointer", { + db <- adbc_database_init(adbc_driver_void()) + con <- adbc_connection_init(db) + stmt <- adbc_statement_init(con) + stream <- nanoarrow::basic_array_stream(list(), nanoarrow::na_na()) + + # A compact way to set the external pointer to NULL + db <- unserialize(serialize(db, NULL)) + con <- unserialize(serialize(con, NULL)) + stmt <- unserialize(serialize(stmt, NULL)) + stream <- unserialize(serialize(stream, NULL)) + + expect_false(adbc_xptr_is_valid(db)) + expect_false(adbc_xptr_is_valid(con)) + expect_false(adbc_xptr_is_valid(stmt)) + expect_false(adbc_xptr_is_valid(stream)) +}) +test_that("adbc_xptr_is_valid() errors for non-ADBC objects", { expect_error( adbc_xptr_is_valid(NULL), "must inherit from one of" From d772fd11809361f6671f2255c0b0c571816c6855 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 1 Sep 2023 10:38:18 -0400 Subject: [PATCH 15/20] fix(go/adbc/driver/snowflake): properly handle time fields (#1021) Fixes #1019 --- go/adbc/driver/snowflake/driver_test.go | 14 +++++++++++++- go/adbc/driver/snowflake/record_reader.go | 10 ++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 712f730e6c..89fc566dcf 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -333,13 +333,25 @@ func (suite *SnowflakeTests) TestSqlIngestTimestamp() { sc := arrow.NewSchema([]arrow.Field{{ Name: "col", Type: arrow.FixedWidthTypes.Timestamp_us, Nullable: true, - }}, nil) + }, { + Name: "col2", Type: arrow.FixedWidthTypes.Time64us, + Nullable: true, + }, { + Name: "col3", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + }, nil) bldr := array.NewRecordBuilder(memory.DefaultAllocator, sc) defer bldr.Release() tbldr := bldr.Field(0).(*array.TimestampBuilder) tbldr.AppendValues([]arrow.Timestamp{0, 0, 42}, []bool{false, true, true}) + tmbldr := bldr.Field(1).(*array.Time64Builder) + tmbldr.AppendValues([]arrow.Time64{420000, 0, 86000}, []bool{true, false, true}) + ibldr := bldr.Field(2).(*array.Int64Builder) + ibldr.AppendValues([]int64{-1, 25, 0}, []bool{true, true, false}) + rec := bldr.NewRecord() defer rec.Release() diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go index db0bf0f89f..5b4dbb4960 100644 --- a/go/adbc/driver/snowflake/record_reader.go +++ b/go/adbc/driver/snowflake/record_reader.go @@ -110,9 +110,15 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader) (*arrow. } } case "TIME": - f.Type = arrow.FixedWidthTypes.Time64ns + var dt arrow.DataType + if srcMeta.Scale < 6 { + dt = &arrow.Time32Type{Unit: arrow.TimeUnit(srcMeta.Scale / 3)} + } else { + dt = &arrow.Time64Type{Unit: arrow.TimeUnit(srcMeta.Scale / 3)} + } + f.Type = dt transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) { - return compute.CastArray(ctx, a, compute.SafeCastOptions(f.Type)) + return compute.CastArray(ctx, a, compute.SafeCastOptions(dt)) } case "TIMESTAMP_NTZ": dt := &arrow.TimestampType{Unit: arrow.TimeUnit(srcMeta.Scale / 3)} From 932b721c463ba43899b2abe6fabea8e0386d4327 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Sep 2023 10:57:26 -0400 Subject: [PATCH 16/20] fix(c/driver/sqlite): escape table names in INSERT, too (#1003) Fixes #1000. --- c/driver/sqlite/sqlite.c | 57 +++++++++++++++++----------------- c/driver/sqlite/sqlite_test.cc | 31 ++++++++++++++++++ 2 files changed, 59 insertions(+), 29 deletions(-) diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c index 7b4072c2ba..834552e0aa 100644 --- a/c/driver/sqlite/sqlite.c +++ b/c/driver/sqlite/sqlite.c @@ -1081,26 +1081,28 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, sqlite3_str* create_query = sqlite3_str_new(NULL); if (sqlite3_str_errcode(create_query)) { SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + sqlite3_free(sqlite3_str_finish(create_query)); return ADBC_STATUS_INTERNAL; } - struct StringBuilder insert_query = {0}; - if (StringBuilderInit(&insert_query, /*initial_size=*/256) != 0) { - SetError(error, "[SQLite] Could not initiate StringBuilder"); + sqlite3_str* insert_query = sqlite3_str_new(NULL); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); sqlite3_free(sqlite3_str_finish(create_query)); + sqlite3_free(sqlite3_str_finish(insert_query)); return ADBC_STATUS_INTERNAL; } - sqlite3_str_appendf(create_query, "%s%Q%s", "CREATE TABLE ", stmt->target_table, " ("); + sqlite3_str_appendf(create_query, "CREATE TABLE %Q (", stmt->target_table); if (sqlite3_str_errcode(create_query)) { - SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } - if (StringBuilderAppend(&insert_query, "%s%s%s", "INSERT INTO ", stmt->target_table, - " VALUES (") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendf(insert_query, "INSERT INTO %Q VALUES (", stmt->target_table); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1111,7 +1113,8 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, if (i > 0) { sqlite3_str_appendf(create_query, "%s", ", "); if (sqlite3_str_errcode(create_query)) { - SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + SetError(error, "[SQLite] Failed to build CREATE: %s", + sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1119,7 +1122,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, sqlite3_str_appendf(create_query, "%Q", stmt->binder.schema.children[i]->name); if (sqlite3_str_errcode(create_query)) { - SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1127,7 +1130,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, int status = ArrowSchemaViewInit(&view, stmt->binder.schema.children[i], &arrow_error); if (status != 0) { - SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i, + SetError(error, "[SQLite] Failed to parse schema for column %d: %s (%d): %s", i, strerror(status), status, arrow_error.message); code = ADBC_STATUS_INTERNAL; goto cleanup; @@ -1160,16 +1163,9 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, break; } - if (i > 0) { - if (StringBuilderAppend(&insert_query, "%s", ", ") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); - code = ADBC_STATUS_INTERNAL; - goto cleanup; - } - } - - if (StringBuilderAppend(&insert_query, "%s", "?") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : "")); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1177,13 +1173,14 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, sqlite3_str_appendchar(create_query, 1, ')'); if (sqlite3_str_errcode(create_query)) { - SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } - if (StringBuilderAppend(&insert_query, "%s", ")") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendchar(insert_query, 1, ')'); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1207,11 +1204,13 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, } if (code == ADBC_STATUS_OK) { - int rc = sqlite3_prepare_v2(stmt->conn, insert_query.buffer, (int)insert_query.size, - insert_statement, /*pzTail=*/NULL); + int rc = sqlite3_prepare_v2(stmt->conn, sqlite3_str_value(insert_query), + sqlite3_str_length(insert_query), insert_statement, + /*pzTail=*/NULL); if (rc != SQLITE_OK) { - SetError(error, "[SQLite] Failed to prepare statement: %s (executed '%s')", - sqlite3_errmsg(stmt->conn), insert_query.buffer); + SetError(error, "[SQLite] Failed to prepare statement: %s (executed '%.*s')", + sqlite3_errmsg(stmt->conn), sqlite3_str_length(insert_query), + sqlite3_str_value(insert_query)); code = ADBC_STATUS_INTERNAL; } } @@ -1220,7 +1219,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, cleanup: sqlite3_free(sqlite3_str_finish(create_query)); - StringBuilderReset(&insert_query); + sqlite3_free(sqlite3_str_finish(insert_query)); return code; } diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index d2821a306a..e5234b9a12 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -253,6 +253,37 @@ class SqliteStatementTest : public ::testing::Test, }; ADBCV_TEST_STATEMENT(SqliteStatementTest) +TEST_F(SqliteStatementTest, SqlIngestNameEscaping) { + ASSERT_THAT(quirks()->DropTable(&connection, "\"test-table\"", &error), + adbc_validation::IsOkStatus(&error)); + + std::string table = "test-table"; + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + ASSERT_THAT( + adbc_validation::MakeSchema(&schema.value, {{"index", NANOARROW_TYPE_INT64}, + {"create", NANOARROW_TYPE_STRING}}), + adbc_validation::IsOkErrno()); + ASSERT_THAT((adbc_validation::MakeBatch( + &schema.value, &array.value, &na_error, {42, -42, std::nullopt}, + {"foo", std::nullopt, ""})), + adbc_validation::IsOkErrno(&na_error)); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + table.c_str(), &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + adbc_validation::IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_EQ(3, rows_affected); +} + // -- SQLite Specific Tests ------------------------------------------ constexpr size_t kInferRows = 16; From 3d1d6ccfa9b2e03fd98fc931a4fe8b4de59cc376 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Sep 2023 13:19:23 -0400 Subject: [PATCH 17/20] fix(c/driver/postgresql): suppress console spam (#1027) Fixes #1023. --- c/driver/postgresql/connection.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index c106ffca78..d20b3ae1f0 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -744,6 +744,10 @@ class PqGetObjectsHelper { struct ArrowArray* fk_column_name_col_; }; +// A notice processor that does nothing with notices. In the future we can log +// these, but this suppresses the default of printing to stderr. +void SilentNoticeProcessor(void* /*arg*/, const char* /*message*/) {} + } // namespace AdbcStatusCode PostgresConnection::Cancel(struct AdbcError* error) { @@ -1398,12 +1402,17 @@ AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database, database_ = *reinterpret_cast*>(database->private_data); type_resolver_ = database_->type_resolver(); + RAISE_ADBC(database_->Connect(&conn_, error)); + cancel_ = PQgetCancel(conn_); if (!cancel_) { SetError(error, "[libpq] Could not initialize PGcancel"); return ADBC_STATUS_UNKNOWN; } + + std::ignore = PQsetNoticeProcessor(conn_, SilentNoticeProcessor, nullptr); + return ADBC_STATUS_OK; } From 9d624c610b9891d09f4b3736b9b633e50a3b7f7e Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Sep 2023 13:56:27 -0400 Subject: [PATCH 18/20] docs: describe how to use SSO with Snowflake (#1030) Fixes #841. --------- Co-authored-by: Dewey Dunnington --- docs/source/driver/snowflake.rst | 57 ++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/docs/source/driver/snowflake.rst b/docs/source/driver/snowflake.rst index 3e17b5b060..52db342822 100644 --- a/docs/source/driver/snowflake.rst +++ b/docs/source/driver/snowflake.rst @@ -165,7 +165,60 @@ password can be provided in the URI or via the ``username`` and ``password`` options to the :cpp:class:`AdbcDatabase`. Alternately, other types of authentication can be specified and customized. -See "Client Options" below. +See "Client Options" below for details on all the options. + +SSO Authentication +~~~~~~~~~~~~~~~~~~ + +Snowflake supports `single sign-on +`_. +If your account has been configured with SSO, it can be used with the +Snowflake driver by setting the following options when constructing the +:cpp:class:`AdbcDatabase`: + +- ``adbc.snowflake.sql.account``: your Snowflake account. (For example, if + you log in to ``https://foobar.snowflakecomputing.com``, then your account + identifier is ``foobar``.) +- ``adbc.snowflake.sql.auth_type``: ``auth_ext_browser``. +- ``username``: your username. (This should probably be your email, + e.g. ``jdoe@example.com``.) + +A new browser tab or window should appear where you can continue the login. +Once this is complete, you will have a complete ADBC database/connection +object. Some users have reported needing other configuration options, such as +``adbc.snowflake.sql.region`` and ``adbc.snowflake.sql.uri.*`` (see below for +a listing). + +.. tab-set:: + + .. tab-item:: Python + :sync: python + + .. code-block:: python + + import adbc_driver_snowflake.dbapi + # This will open a new browser tab, and block until you log in. + adbc_driver_snowflake.dbapi.connect(db_kwargs={ + "adbc.snowflake.sql.account": "foobar", + "adbc.snowflake.sql.auth_type": "auth_ext_browser", + "username": "jdoe@example.com", + }) + + .. tab-item:: R + :sync: r + + .. code-block:: r + + library(adbcdrivermanager) + db <- adbc_database_init( + adbcsnowflake::adbcsnowflake(), + adbc.snowflake.sql.account = 'foobar', + adbc.snowflake.sql.auth_type = 'auth_ext_browser' + username = 'jdoe@example.com', + ) + # This will open a new browser tab, and block until you log in. + con <- adbc_connection_init(db) + Bulk Ingestion -------------- @@ -198,7 +251,7 @@ In addition, the current database and schema for the session must be set. If these are not set, the ``CREATE TEMPORARY STAGE`` command executed by the driver can fail with the following error: -.. code-block:: +.. code-block:: sql CREATE TEMPORARY STAGE SYSTEM$BIND file_format=(type=csv field_optionally_enclosed_by='"') CANNOT perform CREATE STAGE. This session does not have a current schema. Call 'USE SCHEMA' or use a qualified name. From 1bc8a7faaab630e13c0d03bb3530acc1c7b37fc5 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Sep 2023 13:56:41 -0400 Subject: [PATCH 19/20] fix(c/driver_manager): fix crash when error is null (#1029) Fixes #1028. --- c/driver_manager/adbc_driver_manager.cc | 3 ++- c/driver_manager/adbc_driver_manager_test.cc | 15 +++++++++++++++ go/adbc/drivermgr/adbc_driver_manager.cc | 3 ++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index e4287534df..c28bea931f 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -651,7 +651,8 @@ const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream } #define INIT_ERROR(ERROR, SOURCE) \ - if ((ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ + if ((ERROR) != nullptr && \ + (ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ (ERROR)->private_driver = (SOURCE)->private_driver; \ } diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc index 5e70390388..58d056c499 100644 --- a/c/driver_manager/adbc_driver_manager_test.cc +++ b/c/driver_manager/adbc_driver_manager_test.cc @@ -33,6 +33,7 @@ std::string AdbcDriverManagerDefaultEntrypoint(const std::string& filename); namespace adbc { +using adbc_validation::Handle; using adbc_validation::IsOkStatus; using adbc_validation::IsStatus; @@ -227,6 +228,20 @@ class SqliteDatabaseTest : public ::testing::Test, public adbc_validation::Datab }; ADBCV_TEST_DATABASE(SqliteDatabaseTest) +TEST_F(SqliteDatabaseTest, NullError) { + Handle conn; + + ASSERT_THAT(AdbcDatabaseNew(&database, nullptr), IsOkStatus()); + ASSERT_THAT(quirks()->SetupDatabase(&database, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcDatabaseInit(&database, nullptr), IsOkStatus()); + + ASSERT_THAT(AdbcConnectionNew(&conn.value, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcConnectionInit(&conn.value, &database, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcConnectionRelease(&conn.value, nullptr), IsOkStatus()); + + ASSERT_THAT(AdbcDatabaseRelease(&database, nullptr), IsOkStatus()); +} + class SqliteConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: diff --git a/go/adbc/drivermgr/adbc_driver_manager.cc b/go/adbc/drivermgr/adbc_driver_manager.cc index e4287534df..c28bea931f 100644 --- a/go/adbc/drivermgr/adbc_driver_manager.cc +++ b/go/adbc/drivermgr/adbc_driver_manager.cc @@ -651,7 +651,8 @@ const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream } #define INIT_ERROR(ERROR, SOURCE) \ - if ((ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ + if ((ERROR) != nullptr && \ + (ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ (ERROR)->private_driver = (SOURCE)->private_driver; \ } From 2d27f17d8d67888e1ceaf97915eabfe6cd5dc2b4 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Sep 2023 13:56:55 -0400 Subject: [PATCH 20/20] fix(c/driver): return NOT_FOUND for GetTableSchema (#1026) Fixes #1022. --- c/driver/postgresql/connection.cc | 8 +++++++- c/driver/postgresql/statement.cc | 1 + c/driver/sqlite/sqlite.c | 5 +++-- c/integration/duckdb/duckdb_test.cc | 1 + c/validation/adbc_validation.cc | 13 +++++++++++++ c/validation/adbc_validation.h | 4 ++++ go/adbc/driver/snowflake/driver.go | 4 ++++ .../adbc_driver_postgresql/tests/test_lowlevel.py | 5 +++++ python/adbc_driver_sqlite/tests/test_lowlevel.py | 5 +++++ 9 files changed, 43 insertions(+), 3 deletions(-) diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index d20b3ae1f0..de37d84c48 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -1312,7 +1312,13 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, StringBuilderReset(&query); RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + auto result = result_helper.Execute(); + if (result != ADBC_STATUS_OK) { + if (std::string(error->sqlstate, 5) == "42P01") { + return ADBC_STATUS_NOT_FOUND; + } + return result; + } auto uschema = nanoarrow::UniqueSchema(); ArrowSchemaInit(uschema.get()); diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 6ea52b53ea..c1aaa1f63e 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -651,6 +651,7 @@ int TupleReader::GetNext(struct ArrowArray* out) { struct ArrowArray tmp; NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error)); + PQclear(result_); // Check the server-side response result_ = PQgetResult(conn_); const ExecStatusType pq_status = PQresultStatus(result_); diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c index 834552e0aa..5678a06451 100644 --- a/c/driver/sqlite/sqlite.c +++ b/c/driver/sqlite/sqlite.c @@ -877,6 +877,7 @@ AdbcStatusCode SqliteConnectionGetTableSchema(struct AdbcConnection* connection, return ADBC_STATUS_INTERNAL; } + // TODO(apache/arrow-adbc#1025): escape if (StringBuilderAppend(&query, "%s%s", "SELECT * FROM ", table_name) != 0) { StringBuilderReset(&query); SetError(error, "[SQLite] Call to StringBuilderAppend failed"); @@ -888,8 +889,8 @@ AdbcStatusCode SqliteConnectionGetTableSchema(struct AdbcConnection* connection, sqlite3_prepare_v2(conn->conn, query.buffer, query.size, &stmt, /*pzTail=*/NULL); StringBuilderReset(&query); if (rc != SQLITE_OK) { - SetError(error, "[SQLite] Failed to prepare query: %s", sqlite3_errmsg(conn->conn)); - return ADBC_STATUS_INTERNAL; + SetError(error, "[SQLite] GetTableSchema: %s", sqlite3_errmsg(conn->conn)); + return ADBC_STATUS_NOT_FOUND; } struct ArrowArrayStream stream = {0}; diff --git a/c/integration/duckdb/duckdb_test.cc b/c/integration/duckdb/duckdb_test.cc index c42c3813cf..a373abd888 100644 --- a/c/integration/duckdb/duckdb_test.cc +++ b/c/integration/duckdb/duckdb_test.cc @@ -75,6 +75,7 @@ class DuckDbConnectionTest : public ::testing::Test, void TestAutocommitDefault() { GTEST_SKIP(); } void TestMetadataGetTableSchema() { GTEST_SKIP(); } + void TestMetadataGetTableSchemaNotFound() { GTEST_SKIP(); } void TestMetadataGetTableTypes() { GTEST_SKIP(); } protected: diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc index afb0260a63..81e9c7bc0c 100644 --- a/c/validation/adbc_validation.cc +++ b/c/validation/adbc_validation.cc @@ -430,6 +430,19 @@ void ConnectionTest::TestMetadataGetTableSchema() { {"strings", NANOARROW_TYPE_STRING, NULLABLE}})); } +void ConnectionTest::TestMetadataGetTableSchemaNotFound() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks()->DropTable(&connection, "thistabledoesnotexist", &error), + IsOkStatus(&error)); + + Handle schema; + ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, + /*db_schema=*/nullptr, "thistabledoesnotexist", + &schema.value, &error), + IsStatus(ADBC_STATUS_NOT_FOUND, &error)); +} + void ConnectionTest::TestMetadataGetTableTypes() { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index 9bedc6a376..a8140ac103 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -188,6 +188,7 @@ class ConnectionTest { void TestMetadataGetInfo(); void TestMetadataGetTableSchema(); + void TestMetadataGetTableSchemaNotFound(); void TestMetadataGetTableTypes(); void TestMetadataGetObjectsCatalogs(); @@ -219,6 +220,9 @@ class ConnectionTest { TEST_F(FIXTURE, MetadataCurrentDbSchema) { TestMetadataCurrentDbSchema(); } \ TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \ TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \ + TEST_F(FIXTURE, MetadataGetTableSchemaNotFound) { \ + TestMetadataGetTableSchemaNotFound(); \ + } \ TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \ TEST_F(FIXTURE, MetadataGetObjectsCatalogs) { TestMetadataGetObjectsCatalogs(); } \ TEST_F(FIXTURE, MetadataGetObjectsDbSchemas) { TestMetadataGetObjectsDbSchemas(); } \ diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index a00513817b..1f74ef9b94 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -163,6 +163,10 @@ func errToAdbcErr(code adbc.Status, err error) error { var sqlstate [5]byte copy(sqlstate[:], []byte(sferr.SQLState)) + if sferr.SQLState == "42S02" { + code = adbc.StatusNotFound + } + return adbc.Error{ Code: code, Msg: sferr.Error(), diff --git a/python/adbc_driver_postgresql/tests/test_lowlevel.py b/python/adbc_driver_postgresql/tests/test_lowlevel.py index 1e79713b38..b4c4dcb658 100644 --- a/python/adbc_driver_postgresql/tests/test_lowlevel.py +++ b/python/adbc_driver_postgresql/tests/test_lowlevel.py @@ -33,6 +33,11 @@ def postgres( yield conn +def test_connection_get_table_schema(postgres: adbc_driver_manager.AdbcConnection): + with pytest.raises(adbc_driver_manager.ProgrammingError, match="NOT_FOUND"): + postgres.get_table_schema(None, None, "thistabledoesnotexist") + + def test_query_trivial(postgres: adbc_driver_manager.AdbcConnection) -> None: with adbc_driver_manager.AdbcStatement(postgres) as stmt: stmt.set_sql_query("SELECT 1") diff --git a/python/adbc_driver_sqlite/tests/test_lowlevel.py b/python/adbc_driver_sqlite/tests/test_lowlevel.py index 58359a11bf..9c8afcac3b 100644 --- a/python/adbc_driver_sqlite/tests/test_lowlevel.py +++ b/python/adbc_driver_sqlite/tests/test_lowlevel.py @@ -29,6 +29,11 @@ def sqlite(): yield conn +def test_connection_get_table_schema(sqlite): + with pytest.raises(adbc_driver_manager.ProgrammingError, match="NOT_FOUND"): + sqlite.get_table_schema(None, None, "thistabledoesnotexist") + + def test_query_trivial(sqlite): with adbc_driver_manager.AdbcStatement(sqlite) as stmt: stmt.set_sql_query("SELECT 1")