Skip to content

Commit

Permalink
feat(c/driver/postgresql): Enable basic connect/query workflow for Re…
Browse files Browse the repository at this point in the history
…dshift (#2219)

Just following up on #1563 to see if the missing `typarray` column is
the only issue. To get the details right might be a large project, but
we might be able to support a basic connection without too much effort.
Parameter binding and non-COPY result fetching seem to work...the default
query fetch method (COPY) is not supported, `connection_get_info()`
fails, and at a glance, `connection_get_objects()` might be returning
incorrect results (and fails at the column depth).

``` r
library(adbcdrivermanager)

db <- adbc_database_init(
  adbcpostgresql::adbcpostgresql(),
  uri = Sys.getenv("ADBC_REDSHIFT_TEST_URI"),
  adbc.postgresql.load_array_types = FALSE
)

con <- db |> 
  adbc_connection_init()

stmt <- con |> 
  adbc_statement_init(adbc.postgresql.use_copy = FALSE)

stream <- nanoarrow::nanoarrow_allocate_array_stream()
stmt |> 
  adbc_statement_bind(data.frame(45)) |> 
  adbc_statement_set_sql_query("SELECT 1 + $1 as foofy, 'string' as foofy_string") |> 
  adbc_statement_execute_query(stream)
#> [1] -1

tibble::as_tibble(stream)
#> # A tibble: 1 × 2
#>   foofy foofy_string
#>   <dbl> <chr>       
#> 1    46 string
```

<sup>Created on 2024-10-04 with [reprex
v2.1.1](https://reprex.tidyverse.org)</sup>

---------

Co-authored-by: William Ayd <[email protected]>
  • Loading branch information
paleolimbot and WillAyd authored Nov 5, 2024
1 parent 9ac8f6c commit 28b8c1b
Show file tree
Hide file tree
Showing 13 changed files with 339 additions and 154 deletions.
91 changes: 64 additions & 27 deletions c/driver/postgresql/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "connection.h"

#include <array>
#include <cassert>
#include <cinttypes>
#include <cmath>
Expand Down Expand Up @@ -175,6 +176,13 @@ class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper {
all_constraints_(conn, kConstraintsQueryAll),
some_constraints_(conn, ConstraintsQuery()) {}

// Allow Redshift to execute this query without constraints
// TODO(paleolimbot): Investigate to see if we can simplify the constraints query so that
// it works on both!
//
// Sets the flag that gates the constraints queries: when the flag is false, the
// column-loading step only runs the column queries and leaves the
// all_constraints_/some_constraints_ helpers unexecuted. The flag defaults to
// true; GetObjects() passes false when the vendor name is "Redshift".
void SetEnableConstraints(bool enable_constraints) {
enable_constraints_ = enable_constraints;
}

Status Load(adbc::driver::GetObjectsDepth depth,
std::optional<std::string_view> catalog_filter,
std::optional<std::string_view> schema_filter,
Expand Down Expand Up @@ -262,16 +270,23 @@ class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper {
std::optional<std::string_view> column_filter) override {
if (column_filter.has_value()) {
UNWRAP_STATUS(some_columns_.Execute(
{std::string(schema), std::string(table), std::string(*column_filter)}))
UNWRAP_STATUS(some_constraints_.Execute(
{std::string(schema), std::string(table), std::string(*column_filter)}))
{std::string(schema), std::string(table), std::string(*column_filter)}));
next_column_ = some_columns_.Row(-1);
next_constraint_ = some_constraints_.Row(-1);
} else {
UNWRAP_STATUS(all_columns_.Execute({std::string(schema), std::string(table)}))
UNWRAP_STATUS(all_constraints_.Execute({std::string(schema), std::string(table)}))
UNWRAP_STATUS(all_columns_.Execute({std::string(schema), std::string(table)}));
next_column_ = all_columns_.Row(-1);
next_constraint_ = all_constraints_.Row(-1);
}

if (enable_constraints_) {
if (column_filter.has_value()) {
UNWRAP_STATUS(some_constraints_.Execute(
{std::string(schema), std::string(table), std::string(*column_filter)}))
next_constraint_ = some_constraints_.Row(-1);
} else {
UNWRAP_STATUS(
all_constraints_.Execute({std::string(schema), std::string(table)}));
next_constraint_ = all_constraints_.Row(-1);
}
}

return Status::Ok();
Expand Down Expand Up @@ -348,6 +363,9 @@ class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper {
PqResultHelper all_constraints_;
PqResultHelper some_constraints_;

// On Redshift, the constraints query fails
bool enable_constraints_{true};

// Iterator state for the catalogs/schema/table/column queries
PqResultRow next_catalog_;
PqResultRow next_schema_;
Expand Down Expand Up @@ -478,19 +496,30 @@ AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection,
for (size_t i = 0; i < info_codes_length; i++) {
switch (info_codes[i]) {
case ADBC_INFO_VENDOR_NAME:
infos.push_back({info_codes[i], "PostgreSQL"});
infos.push_back({info_codes[i], std::string(VendorName())});
break;
case ADBC_INFO_VENDOR_VERSION: {
const char* stmt = "SHOW server_version_num";
auto result_helper = PqResultHelper{conn_, std::string(stmt)};
RAISE_STATUS(error, result_helper.Execute());
auto it = result_helper.begin();
if (it == result_helper.end()) {
SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt);
return ADBC_STATUS_INTERNAL;
if (VendorName() == "Redshift") {
const std::array<int, 3>& version = VendorVersion();
std::string version_string = std::to_string(version[0]) + "." +
std::to_string(version[1]) + "." +
std::to_string(version[2]);
infos.push_back({info_codes[i], std::move(version_string)});

} else {
// Gives a version in the form 140000 instead of 14.0.0
const char* stmt = "SHOW server_version_num";
auto result_helper = PqResultHelper{conn_, std::string(stmt)};
RAISE_STATUS(error, result_helper.Execute());
auto it = result_helper.begin();
if (it == result_helper.end()) {
SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt);
return ADBC_STATUS_INTERNAL;
}
const char* server_version_num = (*it)[0].data;
infos.push_back({info_codes[i], server_version_num});
}
const char* server_version_num = (*it)[0].data;
infos.push_back({info_codes[i], server_version_num});

break;
}
case ADBC_INFO_DRIVER_NAME:
Expand Down Expand Up @@ -520,7 +549,8 @@ AdbcStatusCode PostgresConnection::GetObjects(
struct AdbcConnection* connection, int c_depth, const char* catalog,
const char* db_schema, const char* table_name, const char** table_type,
const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error) {
PostgresGetObjectsHelper new_helper(conn_);
PostgresGetObjectsHelper helper(conn_);
helper.SetEnableConstraints(VendorName() != "Redshift");

const auto catalog_filter =
catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt;
Expand Down Expand Up @@ -559,9 +589,9 @@ AdbcStatusCode PostgresConnection::GetObjects(
.ToAdbc(error);
}

auto status = BuildGetObjects(&new_helper, depth, catalog_filter, schema_filter,
auto status = BuildGetObjects(&helper, depth, catalog_filter, schema_filter,
table_filter, column_filter, table_type_filter, out);
RAISE_STATUS(error, new_helper.Close());
RAISE_STATUS(error, helper.Close());
RAISE_STATUS(error, status);

return ADBC_STATUS_OK;
Expand All @@ -573,11 +603,12 @@ AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value,
if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) {
output = PQdb(conn_);
} else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) {
PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA"};
PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA()"};
RAISE_STATUS(error, result_helper.Execute());
auto it = result_helper.begin();
if (it == result_helper.end()) {
SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'");
SetError(error,
"[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA()'");
return ADBC_STATUS_INTERNAL;
}
output = (*it)[0].data;
Expand Down Expand Up @@ -989,22 +1020,22 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,
CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), result_helper.NumRows()),
error);

ArrowError na_error;
int row_counter = 0;
for (auto row : result_helper) {
const char* colname = row[0].data;
const Oid pg_oid =
static_cast<uint32_t>(std::strtol(row[1].data, /*str_end=*/nullptr, /*base=*/10));

PostgresType pg_type;
if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, " (\"", colname,
"\") has unknown type code ", pg_oid);
if (type_resolver_->FindWithDefault(pg_oid, &pg_type) != NANOARROW_OK) {
SetError(error, "%s%d%s%s%s%" PRIu32, "Error resolving type code for column #",
row_counter + 1, " (\"", colname, "\") with oid ", pg_oid);
final_status = ADBC_STATUS_NOT_IMPLEMENTED;
break;
}
CHECK_NA(INTERNAL,
pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]),
pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter],
std::string(VendorName())),
error);
row_counter++;
}
Expand Down Expand Up @@ -1136,4 +1167,10 @@ AdbcStatusCode PostgresConnection::SetOptionInt(const char* key, int64_t value,
return ADBC_STATUS_NOT_IMPLEMENTED;
}

// Returns the vendor name reported by the database object that owns this
// connection (compared against "Redshift" elsewhere in this file to select
// vendor-specific behaviour).
std::string_view PostgresConnection::VendorName() {
  std::string_view vendor = database_->VendorName();
  return vendor;
}

// Returns the three-component vendor version held by the database object that
// owns this connection. GetInfo() joins the three components with "." to build
// the version string for Redshift.
const std::array<int, 3>& PostgresConnection::VendorVersion() {
  const std::array<int, 3>& version = database_->VendorVersion();
  return version;
}

} // namespace adbcpq
3 changes: 3 additions & 0 deletions c/driver/postgresql/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <array>
#include <cstdint>
#include <memory>

Expand Down Expand Up @@ -73,6 +74,8 @@ class PostgresConnection {
return type_resolver_;
}
bool autocommit() const { return autocommit_; }
std::string_view VendorName();
const std::array<int, 3>& VendorVersion();

private:
std::shared_ptr<PostgresDatabase> database_;
Expand Down
2 changes: 1 addition & 1 deletion c/driver/postgresql/copy/postgres_copy_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PostgresCopyStreamTester {
public:
ArrowErrorCode Init(const PostgresType& root_type, ArrowError* error = nullptr) {
NANOARROW_RETURN_NOT_OK(reader_.Init(root_type));
NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema(error));
NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema("PostgreSQL Tester", error));
NANOARROW_RETURN_NOT_OK(reader_.InitFieldReaders(error));
return NANOARROW_OK;
}
Expand Down
5 changes: 3 additions & 2 deletions c/driver/postgresql/copy/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,10 +972,11 @@ class PostgresCopyStreamReader {
return NANOARROW_OK;
}

ArrowErrorCode InferOutputSchema(ArrowError* error) {
ArrowErrorCode InferOutputSchema(const std::string& vendor_name, ArrowError* error) {
schema_.reset();
ArrowSchemaInit(schema_.get());
NANOARROW_RETURN_NOT_OK(root_reader_.InputType().SetSchema(schema_.get()));
NANOARROW_RETURN_NOT_OK(
root_reader_.InputType().SetSchema(schema_.get(), vendor_name));
return NANOARROW_OK;
}

Expand Down
Loading

0 comments on commit 28b8c1b

Please sign in to comment.