From cc0c4e9c01b1feb7c175b22f7cc1b0d019751745 Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Tue, 21 Nov 2023 08:49:40 -0800 Subject: [PATCH] feat(go/adbc/drivermgr): Implement Remaining CGO Wrapper Methods that are Supported by SQLite Driver (#1304) # What? Implementations for the following methods in the CGO wrapper for `adbc_driver_manager`: - `GetTableSchema` - `GetTableTypes` - `Commit` - `Rollback` - `GetParameterSchema` - `BindStream` # Why? Functionality exists in the C++ driver manager but is not yet accessible via the Go driver interface. # Notes Three methods in the wrapper remain unimplemented: `ExecutePartitions`, `ReadPartition`, and `SetSubstraitPlan`. These methods are not currently supported by the SQLite driver, which is the primary test target for these changes. It is still possible to implement them in the drivermgr wrapper without support in specific drivers, but it does make it more difficult to verify correct behavior. The effort to add those methods will likely involve some additional work to ensure we are able to test their behaviors, so they are being left out of this current round of implementations. 
Closes part of: #1291 --- go/adbc/drivermgr/wrapper.go | 92 ++++++++++++- go/adbc/drivermgr/wrapper_sqlite_test.go | 163 ++++++++++++++++++++++- 2 files changed, 244 insertions(+), 11 deletions(-) diff --git a/go/adbc/drivermgr/wrapper.go b/go/adbc/drivermgr/wrapper.go index 1e131e51c7..07bb94b814 100644 --- a/go/adbc/drivermgr/wrapper.go +++ b/go/adbc/drivermgr/wrapper.go @@ -32,6 +32,10 @@ package drivermgr // return (struct ArrowArray*)malloc(sizeof(struct ArrowArray)); // } // +// struct ArrowArrayStream* allocArrStream() { +// return (struct ArrowArrayStream*)malloc(sizeof(struct ArrowArrayStream)); +// } +// import "C" import ( "context" @@ -186,6 +190,15 @@ func getRdr(out *C.struct_ArrowArrayStream) (array.RecordReader, error) { return rdr.(array.RecordReader), nil } +func getSchema(out *C.struct_ArrowSchema) (*arrow.Schema, error) { + // Maybe: ImportCArrowSchema should perform this check? + if out.format == nil { + return nil, nil + } + + return cdata.ImportCArrowSchema((*cdata.CArrowSchema)(unsafe.Pointer(out))) +} + type cnxn struct { conn *C.struct_AdbcConnection } @@ -255,19 +268,68 @@ func (c *cnxn) GetObjects(_ context.Context, depth adbc.ObjectDepth, catalog, db } func (c *cnxn) GetTableSchema(_ context.Context, catalog, dbSchema *string, tableName string) (*arrow.Schema, error) { - return nil, &adbc.Error{Code: adbc.StatusNotImplemented} + var ( + schema C.struct_ArrowSchema + err C.struct_AdbcError + catalog_ *C.char + dbSchema_ *C.char + tableName_ *C.char + ) + + if catalog != nil { + catalog_ = C.CString(*catalog) + defer C.free(unsafe.Pointer(catalog_)) + } + + if dbSchema != nil { + dbSchema_ = C.CString(*dbSchema) + defer C.free(unsafe.Pointer(dbSchema_)) + } + + tableName_ = C.CString(tableName) + defer C.free(unsafe.Pointer(tableName_)) + + if code := adbc.Status(C.AdbcConnectionGetTableSchema(c.conn, catalog_, dbSchema_, tableName_, &schema, &err)); code != adbc.StatusOK { + return nil, toAdbcError(code, &err) + } + + return 
getSchema(&schema) } func (c *cnxn) GetTableTypes(context.Context) (array.RecordReader, error) { - return nil, &adbc.Error{Code: adbc.StatusNotImplemented} + var ( + out C.struct_ArrowArrayStream + err C.struct_AdbcError + ) + + if code := adbc.Status(C.AdbcConnectionGetTableTypes(c.conn, &out, &err)); code != adbc.StatusOK { + return nil, toAdbcError(code, &err) + } + return getRdr(&out) } func (c *cnxn) Commit(context.Context) error { - return &adbc.Error{Code: adbc.StatusNotImplemented} + var ( + err C.struct_AdbcError + ) + + if code := adbc.Status(C.AdbcConnectionCommit(c.conn, &err)); code != adbc.StatusOK { + return toAdbcError(code, &err) + } + + return nil } func (c *cnxn) Rollback(context.Context) error { - return &adbc.Error{Code: adbc.StatusNotImplemented} + var ( + err C.struct_AdbcError + ) + + if code := adbc.Status(C.AdbcConnectionRollback(c.conn, &err)); code != adbc.StatusOK { + return toAdbcError(code, &err) + } + + return nil } func (c *cnxn) NewStatement() (adbc.Statement, error) { @@ -405,11 +467,29 @@ func (s *stmt) Bind(_ context.Context, values arrow.Record) error { } func (s *stmt) BindStream(_ context.Context, stream array.RecordReader) error { - return &adbc.Error{Code: adbc.StatusNotImplemented} + var ( + arrStream = C.allocArrStream() + cdArrStream = (*cdata.CArrowArrayStream)(unsafe.Pointer(arrStream)) + err C.struct_AdbcError + ) + cdata.ExportRecordReader(stream, cdArrStream) + if code := adbc.Status(C.AdbcStatementBindStream(s.st, arrStream, &err)); code != adbc.StatusOK { + return toAdbcError(code, &err) + } + return nil } func (s *stmt) GetParameterSchema() (*arrow.Schema, error) { - return nil, &adbc.Error{Code: adbc.StatusNotImplemented} + var ( + schema C.struct_ArrowSchema + err C.struct_AdbcError + ) + + if code := adbc.Status(C.AdbcStatementGetParameterSchema(s.st, &schema, &err)); code != adbc.StatusOK { + return nil, toAdbcError(code, &err) + } + + return getSchema(&schema) } func (s *stmt) 
ExecutePartitions(context.Context) (*arrow.Schema, adbc.Partitions, int64, error) { diff --git a/go/adbc/drivermgr/wrapper_sqlite_test.go b/go/adbc/drivermgr/wrapper_sqlite_test.go index 580b046731..c33adf2792 100644 --- a/go/adbc/drivermgr/wrapper_sqlite_test.go +++ b/go/adbc/drivermgr/wrapper_sqlite_test.go @@ -53,20 +53,25 @@ func (dm *DriverMgrSuite) SetupSuite() { }) dm.NoError(err) - db, err := dm.db.Open(dm.ctx) + cnxn, err := dm.db.Open(dm.ctx) dm.NoError(err) - defer db.Close() + defer cnxn.Close() - stmt, err := db.NewStatement() + stmt, err := cnxn.NewStatement() dm.NoError(err) defer stmt.Close() - err = stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)") - dm.NoError(err) + dm.NoError(stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)")) nrows, err := stmt.ExecuteUpdate(dm.ctx) dm.NoError(err) dm.Equal(int64(0), nrows) + + dm.NoError(stmt.SetSqlQuery("INSERT INTO test_table (id, name) VALUES (1, 'test')")) + + nrows, err = stmt.ExecuteUpdate(dm.ctx) + dm.NoError(err) + dm.Equal(int64(1), nrows) } func (dm *DriverMgrSuite) SetupTest() { @@ -334,6 +339,83 @@ func (dm *DriverMgrSuite) TestGetObjectsTableType() { dm.False(rdr.Next()) } +func (dm *DriverMgrSuite) TestGetTableSchema() { + schema, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "test_table") + dm.NoError(err) + + expSchema := arrow.NewSchema( + []arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "name", Type: arrow.BinaryTypes.String, Nullable: true}, + }, nil) + dm.True(expSchema.Equal(schema)) +} + +func (dm *DriverMgrSuite) TestGetTableSchemaInvalidTable() { + _, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "unknown_table") + dm.Error(err) +} + +func (dm *DriverMgrSuite) TestGetTableSchemaCatalog() { + catalog := "does_not_exist" + schema, err := dm.conn.GetTableSchema(dm.ctx, &catalog, nil, "test_table") + dm.NoError(err) + dm.Nil(schema) +} + +func (dm *DriverMgrSuite) 
TestGetTableSchemaDBSchema() { + dbSchema := "does_not_exist" + schema, err := dm.conn.GetTableSchema(dm.ctx, nil, &dbSchema, "test_table") + dm.NoError(err) + dm.Nil(schema) +} + +func (dm *DriverMgrSuite) TestGetTableTypes() { + rdr, err := dm.conn.GetTableTypes(dm.ctx) + dm.NoError(err) + defer rdr.Release() + + expSchema := adbc.TableTypesSchema + dm.True(expSchema.Equal(rdr.Schema())) + dm.True(rdr.Next()) + + rec := rdr.Record() + dm.Equal(int64(2), rec.NumRows()) + + expTableTypes := []string{"table", "view"} + dm.Contains(expTableTypes, rec.Column(0).ValueStr(0)) + dm.Contains(expTableTypes, rec.Column(0).ValueStr(1)) + dm.False(rdr.Next()) +} + +func (dm *DriverMgrSuite) TestCommit() { + err := dm.conn.Commit(dm.ctx) + dm.Error(err) + dm.ErrorContains(err, "No active transaction, cannot commit") +} + +func (dm *DriverMgrSuite) TestCommitAutocommitDisabled() { + cnxnopt, ok := dm.conn.(adbc.PostInitOptions) + dm.True(ok) + + dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + dm.NoError(dm.conn.Commit(dm.ctx)) +} + +func (dm *DriverMgrSuite) TestRollback() { + err := dm.conn.Rollback(dm.ctx) + dm.Error(err) + dm.ErrorContains(err, "No active transaction, cannot rollback") +} + +func (dm *DriverMgrSuite) TestRollbackAutocommitDisabled() { + cnxnopt, ok := dm.conn.(adbc.PostInitOptions) + dm.True(ok) + + dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + dm.NoError(dm.conn.Rollback(dm.ctx)) +} + func (dm *DriverMgrSuite) TestSqlExecute() { query := "SELECT 1" st, err := dm.conn.NewStatement() @@ -429,6 +511,77 @@ func (dm *DriverMgrSuite) TestSqlPrepareMultipleParams() { dm.False(rdr.Next()) } +func (dm *DriverMgrSuite) TestGetParameterSchema() { + query := "SELECT ?1, ?2" + st, err := dm.conn.NewStatement() + dm.Require().NoError(err) + dm.Require().NoError(st.SetSqlQuery(query)) + defer st.Close() + + expSchema := arrow.NewSchema([]arrow.Field{ + {Name: "?1", Type: arrow.Null, Nullable: true}, 
+ {Name: "?2", Type: arrow.Null, Nullable: true}, + }, nil) + + schema, err := st.GetParameterSchema() + dm.NoError(err) + + dm.True(expSchema.Equal(schema)) +} + +func (dm *DriverMgrSuite) TestBindStream() { + query := "SELECT ?1, ?2" + st, err := dm.conn.NewStatement() + dm.Require().NoError(err) + dm.Require().NoError(st.SetSqlQuery(query)) + defer st.Close() + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "1", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "2", Type: arrow.BinaryTypes.String, Nullable: true}, + }, nil) + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer bldr.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil) + bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"one", "two", "three"}, nil) + + rec1 := bldr.NewRecord() + defer rec1.Release() + + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{4, 5, 6}, nil) + bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"four", "five", "six"}, nil) + + rec2 := bldr.NewRecord() + defer rec2.Release() + + recsIn := []arrow.Record{rec1, rec2} + rdrIn, err := array.NewRecordReader(schema, recsIn) + dm.NoError(err) + + dm.NoError(st.BindStream(dm.ctx, rdrIn)) + + rdrOut, _, err := st.ExecuteQuery(dm.ctx) + dm.NoError(err) + defer rdrOut.Release() + + recsOut := make([]arrow.Record, 0) + for rdrOut.Next() { + rec := rdrOut.Record() + rec.Retain() + defer rec.Release() + recsOut = append(recsOut, rec) + } + + tableIn := array.NewTableFromRecords(schema, recsIn) + defer tableIn.Release() + tableOut := array.NewTableFromRecords(schema, recsOut) + defer tableOut.Release() + + dm.Truef(array.TableEqual(tableIn, tableOut), "expected: %s\ngot: %s", tableIn, tableOut) +} + func TestDriverMgr(t *testing.T) { suite.Run(t, new(DriverMgrSuite)) }