diff --git a/.clang-tidy b/.clang-tidy index ff9dbb0eb3..ff25867f4b 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -17,8 +17,9 @@ --- # Only defaults for now, slowly enable more as desired # Disable clang-analyzer-core, it fires on nanoarrow (I think it doesn't see that ArrowBufferReserve won't leave buffer->data NULL) +# Disable NewDeleteLeaks, it seems to trigger a lot on Googletest # Disable the warning about memset, etc. since it suggests C11 functions # Disable valist, it's buggy: https://github.com/llvm/llvm-project/issues/40656 -Checks: '-clang-analyzer-core.*,-clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,-clang-analyzer-valist.Uninitialized' +Checks: '-clang-analyzer-core.*,-clang-analyzer-cplusplus.NewDeleteLeaks,-clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,-clang-analyzer-valist.Uninitialized' FormatStyle: google UseColor: true diff --git a/.env b/.env index ff1a0e5fde..9e5ff8fae7 100644 --- a/.env +++ b/.env @@ -28,11 +28,12 @@ ARCH_SHORT=amd64 ARCH_CONDA_FORGE=linux_64_ # Default versions for various dependencies -JDK=8 +JDK=11 MANYLINUX=2014 MAVEN=3.6.3 -PYTHON=3.8 -GO=1.21.8 +PLATFORM=linux/amd64 +PYTHON=3.9 +GO=1.22.9 ARROW_MAJOR_VERSION=14 DOTNET=8.0 @@ -40,7 +41,7 @@ DOTNET=8.0 # ci/scripts/install_vcpkg.sh script. Keep in sync with apache/arrow .env. # When updating, also update the docs, which list the version of libpq/SQLite # that vcpkg (and hence our wheels) ship -VCPKG="a42af01b72c28a8e1d7b48107b33e4f286a55ef6" +VCPKG="943c5ef1c8f6b5e6ced092b242c8299caae2ff01" # These are used to tell tests where to find services for integration testing. # They are valid if the services are started with the docker-compose config. diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml index 30036eaf51..9babd1701c 100644 --- a/.github/ISSUE_TEMPLATE/bug.yml +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -30,6 +30,13 @@ body: description: What did you expect to happen? validations: required: true + - type: textarea + id: stack-trace + attributes: + label: Stack Trace + description: Please provide a stack trace if possible. + validations: + required: false - type: textarea id: repro attributes: diff --git a/.github/ISSUE_TEMPLATE/feature.yml b/.github/ISSUE_TEMPLATE/feature.yml index 3de62a242e..9716314c94 100644 --- a/.github/ISSUE_TEMPLATE/feature.yml +++ b/.github/ISSUE_TEMPLATE/feature.yml @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. -name: Feature Request -description: Ask for a feature or improvement. +name: Enhancement/Feature Request +description: Suggest something that could be improved, or a new feature. 
labels: ["Type: enhancement"] body: - type: textarea diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 8ec5a3c2db..88bcf119c7 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -43,6 +43,8 @@ jobs: go-version-file: 'go/adbc/go.mod' check-latest: true - uses: actions/setup-python@v5 + with: + python-version: '3.x' - name: install golangci-lint run: | go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.49.0 diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 78f4ba2436..fec7499d74 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -23,6 +23,8 @@ on: - opened - edited - synchronize + - ready_for_review + - review_requested permissions: contents: read @@ -64,4 +66,4 @@ jobs: env: PR_BODY: ${{ github.event.pull_request.body }} run: | - [[ "${PR_BODY}" =~ @[a-zA-Z0-9]+ ]] && exit 1 || true + python .github/workflows/dev_pr/body_check.py "$PR_BODY" diff --git a/.github/workflows/dev_pr/body_check.py b/.github/workflows/dev_pr/body_check.py new file mode 100644 index 0000000000..10edbe0998 --- /dev/null +++ b/.github/workflows/dev_pr/body_check.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import argparse +import re +import sys +import typing + +PING_RE = re.compile(r"@([a-zA-Z0-9\-]+)") +IGNORED_USERNAMES = {"dependabot"} + + +def check_pr_body(body: str) -> typing.List[str]: + """Check a PR body and return a list of reasons why it's invalid.""" + + reasons = [] + matches = PING_RE.findall(body) + for username in matches: + if username in IGNORED_USERNAMES: + continue + reasons.append(f"Please don't ping {username} in the PR description") + + return reasons + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("body", help="The PR body to check") + + args = parser.parse_args() + + print(f'PR body: "{args.body}"') + print("=" * 60) + + reasons = check_pr_body(args.body) + if not reasons: + print("PR body is valid") + return 0 + + print("PR body is invalid:") + for reason in reasons: + print("-", reason) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 784b535ed9..352f3cca30 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -22,7 +22,6 @@ on: branches: - main paths: - - "adbc.h" - "c/**" - "ci/**" - "go/**" @@ -30,7 +29,6 @@ on: - ".github/workflows/integration.yml" push: paths: - - "adbc.h" - "c/**" - "ci/**" - "go/**" @@ -69,7 +67,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -125,7 +122,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -203,7 +199,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -325,7 +320,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index ac18974157..e2457f29c2 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -42,7 +42,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - java: ['8', '11', '17', '21'] + java: ['11', '17', '21', '22'] steps: - uses: actions/checkout@v4 with: @@ -69,7 +69,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - java: ['11', '17', '21'] + java: ['17', '21'] steps: - uses: actions/checkout@v4 with: diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index 144ea9cfca..770fd2a689 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -22,7 +22,6 @@ on: branches: - main paths: - - "adbc.h" - "c/**" - "ci/**" - "docs/**" @@ -33,7 +32,6 @@ on: - ".github/workflows/native-unix.yml" push: paths: - - "adbc.h" - "c/**" - "ci/**" - "docs/**" @@ -96,7 +94,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: 
- miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -132,12 +129,23 @@ jobs: export PATH=$RUNNER_TOOL_CACHE/go/${GO_VERSION}/${{ matrix.goarch }}/bin:$PATH ./ci/scripts/go_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" + # XXX: GitHub broke upload/download-artifact. To avoid symlinks being + # converted into weird files, tar the files ourselves first. + # https://github.com/apache/arrow-adbc/issues/2061 + # https://github.com/actions/download-artifact/issues/346 + + - name: tar artifacts + shell: bash -l {0} + run: | + cd + tar czf ~/local.tgz local + - uses: actions/upload-artifact@v4 with: name: driver-manager-${{ matrix.os }} retention-days: 3 path: | - ~/local + ~/local.tgz # ------------------------------------------------------------ # C/C++ (builds and tests) # ------------------------------------------------------------ @@ -169,7 +177,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -206,6 +213,67 @@ jobs: run: | ./ci/scripts/cpp_test.sh "$(pwd)/build" + drivers-test-meson: + name: "Meson - C/C++ (Conda/${{ matrix.os }})" + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: ["ubuntu-latest"] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + - name: Install Dependencies + run: | + sudo apt update + sudo apt install -y libpq-dev ninja-build + - name: Get required Go version + run: | + (. .env && echo "GO_VERSION=${GO}") >> $GITHUB_ENV + - uses: actions/setup-go@v5 + with: + go-version: "${{ env.GO_VERSION }}" + check-latest: true + cache: true + cache-dependency-path: go/adbc/go.sum + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install Meson via Python + run: pip install meson + - name: Start SQLite server, Dremio, and postgresql + shell: bash -l {0} + run: | + env POSTGRES_VERSION=16 docker compose up --detach --wait \ + dremio \ + dremio-init \ + flightsql-test \ + flightsql-sqlite-test \ + postgres-test + pip install python-dotenv[cli] + python -m dotenv -f .env list --format simple | tee -a $GITHUB_ENV + - name: Build + run: | + meson setup \ + -Db_sanitize=address,undefined \ + -Ddriver_manager=true \ + -Dflightsql=true \ + -Dpostgresql=true \ + -Dsnowflake=true \ + -Dsqlite=true \ + -Dtests=true \ + c c/build + meson compile -C c/build + - name: Test + run: | + meson test -C c/build --print-errorlogs + - name: Stop SQLite server, Dremio, and postgresql + shell: bash -l {0} + run: | + docker compose down + clang-tidy-conda: name: "clang-tidy" runs-on: ubuntu-latest @@ -226,7 +294,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -276,7 +343,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -296,7 +362,13 @@ jobs: - uses: actions/download-artifact@v4 with: name: driver-manager-${{ matrix.os }} - path: ~/local + path: "~" + + - name: untar artifacts + shell: bash -l {0} + run: | + cd + tar xvf ~/local.tgz - name: Build GLib Driver
Manager shell: bash -l {0} @@ -318,11 +390,18 @@ jobs: strategy: matrix: os: ["macos-13", "macos-latest", "ubuntu-latest", "windows-latest"] + permissions: + contents: 'read' + id-token: 'write' steps: - uses: actions/checkout@v4 with: fetch-depth: 0 persist-credentials: false + - uses: 'google-github-actions/auth@v2' + continue-on-error: true # if auth fails, bigquery driver tests should skip + with: + workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }} - uses: actions/setup-go@v5 with: go-version-file: "go/adbc/go.mod" @@ -369,11 +448,18 @@ jobs: goarch: x64 env: CGO_ENABLED: "1" + permissions: + contents: 'read' + id-token: 'write' steps: - uses: actions/checkout@v4 with: fetch-depth: 0 persist-credentials: false + - uses: 'google-github-actions/auth@v2' + continue-on-error: true # if auth fails, bigquery driver tests should skip + with: + workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }} - name: Get required Go version run: | (. .env && echo "GO_VERSION=${GO}") >> $GITHUB_ENV @@ -389,7 +475,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -418,7 +503,13 @@ jobs: - uses: actions/download-artifact@v4 with: name: driver-manager-${{ matrix.os }} - path: ~/local + path: "~" + + - name: untar artifacts + shell: bash -l {0} + run: | + cd + tar xvf ~/local.tgz - name: Go Build shell: bash -l {0} @@ -451,7 +542,7 @@ jobs: strategy: matrix: os: ["macos-13", "macos-latest", "ubuntu-latest"] - python: ["3.9", "3.11"] + python: ["3.9", "3.12"] env: # Required for macOS # https://conda-forge.org/docs/maintainer/knowledge_base.html#newer-c-features-with-old-sdk @@ -476,7 +567,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -502,7 +592,13 @@ jobs: - uses: actions/download-artifact@v4 with: name: driver-manager-${{ matrix.os }} - path: ~/local + path: "~" + + - name: untar artifacts + shell: bash -l {0} + run: | + cd + tar xvf ~/local.tgz - name: Build shell: bash -l {0} @@ -575,7 +671,7 @@ jobs: strategy: matrix: os: ["ubuntu-latest"] - python: ["3.11"] + python: ["3.12"] steps: - uses: actions/checkout@v4 with: @@ -593,7 +689,6 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -609,10 +704,17 @@ jobs: if: matrix.os == 'ubuntu-latest' run: | sudo sysctl vm.mmap_rnd_bits=28 + - uses: actions/download-artifact@v4 with: name: driver-manager-${{ matrix.os }} - path: ~/local + path: "~" + + - name: untar artifacts + shell: bash -l {0} + run: | + cd + tar xvf ~/local.tgz - name: Build Python shell: bash -l {0} diff --git a/.github/workflows/native-windows.yml b/.github/workflows/native-windows.yml index e78403d03a..9fb3caf228 100644 --- a/.github/workflows/native-windows.yml +++ b/.github/workflows/native-windows.yml @@ -22,7 +22,6 @@ on: branches: - main paths: - - "adbc.h" - "c/**" - "ci/**" - "glib/**" @@ -32,7 +31,6 @@ on: - ".github/workflows/native-windows.yml" push: paths: - - "adbc.h" - 
"c/**" - "ci/**" - "glib/**" @@ -79,14 +77,13 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-mamba: true - name: Install Dependencies - shell: bash -l {0} + shell: pwsh run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt + mamba install -c conda-forge ` + --file ci\conda_env_cpp.txt # Force bundled gtest mamba uninstall gtest @@ -95,6 +92,7 @@ jobs: env: BUILD_ALL: "1" # TODO(apache/arrow-adbc#634) + BUILD_DRIVER_BIGQUERY: "0" BUILD_DRIVER_FLIGHTSQL: "0" BUILD_DRIVER_SNOWFLAKE: "0" run: | @@ -136,14 +134,13 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-mamba: true - name: Install Dependencies - shell: bash -l {0} + shell: pwsh run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt + mamba install -c conda-forge ` + --file ci\conda_env_cpp.txt # Force bundled gtest mamba uninstall gtest @@ -213,14 +210,13 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-mamba: true - name: Install Dependencies - shell: bash -l {0} + shell: pwsh run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt + mamba install -c conda-forge ` + --file ci\conda_env_cpp.txt - uses: actions/setup-go@v5 with: go-version: "${{ env.GO_VERSION }}" @@ -281,16 +277,15 @@ jobs: key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-mamba: true - name: Install Dependencies - shell: bash -l {0} + shell: pwsh run: | - mamba install -c conda-forge \ - python=${{ matrix.python }} \ - --file ci/conda_env_cpp.txt \ - --file ci/conda_env_python.txt + mamba install -c conda-forge ` + python=${{ matrix.python }} ` + --file ci\conda_env_cpp.txt ` + --file ci\conda_env_python.txt - uses: actions/download-artifact@v4 with: diff --git a/.github/workflows/nightly-verify.yml b/.github/workflows/nightly-verify.yml index dfbd3f8bfa..ae6d7f87da 100644 --- a/.github/workflows/nightly-verify.yml +++ b/.github/workflows/nightly-verify.yml @@ -138,7 +138,6 @@ jobs: # The Unix script will set up conda itself if: matrix.os == 'windows-latest' with: - miniforge-variant: Mambaforge miniforge-version: latest use-mamba: true @@ -159,6 +158,9 @@ jobs: VERBOSE: "1" VERIFICATION_MOCK_DIST_DIR: ${{ github.workspace }} run: | + # Rust uses a lot of disk space, free up some space + # https://github.com/actions/runner-images/issues/2840 + sudo rm -rf "$AGENT_TOOLSDIRECTORY" ./arrow-adbc/dev/release/verify-release-candidate.sh $VERSION 0 - name: Verify @@ -174,3 +176,23 @@ jobs: VERIFICATION_MOCK_DIST_DIR: ${{ github.workspace }}\apache-arrow-adbc-${{ env.VERSION }}-rc0 run: | .\arrow-adbc\dev\release\verify-release-candidate.ps1 $env:VERSION 0 + + source-docker: + name: "Run Docker Tests" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + path: arrow-adbc + persist-credentials: false + + - name: cpp-clang-latest + run: | + pushd arrow-adbc + docker compose run 
--rm cpp-clang-latest + + - name: python-debug + run: | + pushd arrow-adbc + docker compose run -e PYTHON=3.12 --rm python-debug diff --git a/.github/workflows/nightly-website.yml b/.github/workflows/nightly-website.yml index 0cf28a7b54..c4a8193ddf 100644 --- a/.github/workflows/nightly-website.yml +++ b/.github/workflows/nightly-website.yml @@ -74,6 +74,11 @@ jobs: with: name: docs path: temp + # To use pip below, we need to install our own Python; the system Python's + # pip won't let us install packages without a scary flag. + - uses: actions/setup-python@v5 + with: + python-version: '3.x' - name: Build shell: bash run: | diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 6b26e99202..135ee1c992 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -23,7 +23,6 @@ on: - main paths: - ".env" - - "adbc.h" - "c/**" - "ci/**" - "glib/**" @@ -133,8 +132,9 @@ jobs: popd - name: Upload Go binaries - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: + name: go-${{ matrix.os }} retention-days: 7 path: | adbc/go/adbc/pkg/libadbc_driver_flightsql.* @@ -168,13 +168,15 @@ jobs: echo "schedule: ${{ github.event.schedule }}" >> $GITHUB_STEP_SUMMARY echo "ref: ${{ github.ref }}" >> $GITHUB_STEP_SUMMARY - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: + pattern: go-* path: adbc/go/adbc/pkg + merge-multiple: true - name: Copy Go binaries run: | - pushd adbc/go/adbc/pkg/artifact + pushd adbc/go/adbc/pkg/ cp *.dll ../ cp *.so ../ cp *.dylib ../ @@ -244,7 +246,7 @@ jobs: docs.tgz java: - name: "Java 1.8" + name: "Java 11" runs-on: ubuntu-latest needs: - source @@ -518,7 +520,6 @@ jobs: - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true @@ -559,13 +560,15 @@ jobs: fail-fast: false matrix: arch: ["amd64", "arm64v8"] - manylinux_version: ["2014"] is_pr: - ${{ startsWith(github.ref, 'refs/pull/') }} exclude: # Don't run arm64v8 build on PRs since the build is excessively slow - arch: arm64v8 is_pr: true + include: + - {arch: amd64, platform: linux/amd64} + - {arch: arm64v8, platform: linux/arm64/v8} steps: - uses: actions/download-artifact@v4 with: @@ -596,12 +599,16 @@ jobs: - name: Build wheel env: ARCH: ${{ matrix.arch }} - MANYLINUX: ${{ matrix.manylinux_version }} + PLATFORM: ${{ matrix.platform }} run: | pushd adbc docker compose run \ -e SETUPTOOLS_SCM_PRETEND_VERSION=$VERSION \ - python-wheel-manylinux + python-wheel-manylinux-build + + docker compose run \ + -e SETUPTOOLS_SCM_PRETEND_VERSION=$VERSION \ + python-wheel-manylinux-relocate popd - name: Archive wheels @@ -610,24 +617,17 @@ jobs: name: python-${{ matrix.arch }}-manylinux${{ matrix.manylinux_version }} retention-days: 7 path: | + adbc/python/adbc_driver_bigquery/repaired_wheels/*.whl adbc/python/adbc_driver_flightsql/repaired_wheels/*.whl adbc/python/adbc_driver_manager/repaired_wheels/*.whl adbc/python/adbc_driver_postgresql/repaired_wheels/*.whl adbc/python/adbc_driver_sqlite/repaired_wheels/*.whl adbc/python/adbc_driver_snowflake/repaired_wheels/*.whl - - name: Test wheel 3.8 - env: - ARCH: ${{ matrix.arch }} - MANYLINUX: ${{ matrix.manylinux_version }} - run: | - pushd adbc - env PYTHON=3.8 docker compose run python-wheel-manylinux-test - - name: Test wheel 3.9 env: ARCH: ${{ matrix.arch }} - MANYLINUX: ${{ matrix.manylinux_version }} + PLATFORM: ${{ matrix.platform }} run: | pushd adbc env PYTHON=3.9 docker compose run 
python-wheel-manylinux-test @@ -635,7 +635,7 @@ jobs: - name: Test wheel 3.10 env: ARCH: ${{ matrix.arch }} - MANYLINUX: ${{ matrix.manylinux_version }} + PLATFORM: ${{ matrix.platform }} run: | pushd adbc env PYTHON=3.10 docker compose run python-wheel-manylinux-test @@ -643,7 +643,7 @@ jobs: - name: Test wheel 3.11 env: ARCH: ${{ matrix.arch }} - MANYLINUX: ${{ matrix.manylinux_version }} + PLATFORM: ${{ matrix.platform }} run: | pushd adbc env PYTHON=3.11 docker compose run python-wheel-manylinux-test @@ -651,7 +651,7 @@ jobs: - name: Test wheel 3.12 env: ARCH: ${{ matrix.arch }} - MANYLINUX: ${{ matrix.manylinux_version }} + PLATFORM: ${{ matrix.platform }} run: | pushd adbc env PYTHON=3.12 docker compose run python-wheel-manylinux-test @@ -728,11 +728,10 @@ jobs: if: matrix.arch == 'amd64' run: | pushd adbc - sudo ci/scripts/install_python.sh macos 3.8 sudo ci/scripts/install_python.sh macos 3.9 popd - - name: Install Python (AMD64 only) + - name: Install Python run: | pushd adbc sudo ci/scripts/install_python.sh macos 3.10 @@ -748,6 +747,7 @@ jobs: $PYTHON -m venv build-env source build-env/bin/activate ./ci/scripts/python_wheel_unix_build.sh $ARCH $(pwd) $(pwd)/build + ./ci/scripts/python_wheel_unix_relocate.sh $ARCH $(pwd) $(pwd)/build popd - name: Archive wheels @@ -756,23 +756,13 @@ jobs: name: python-${{ matrix.arch }}-macos retention-days: 7 path: | + adbc/python/adbc_driver_bigquery/repaired_wheels/*.whl adbc/python/adbc_driver_flightsql/repaired_wheels/*.whl adbc/python/adbc_driver_manager/repaired_wheels/*.whl adbc/python/adbc_driver_postgresql/repaired_wheels/*.whl adbc/python/adbc_driver_sqlite/repaired_wheels/*.whl adbc/python/adbc_driver_snowflake/repaired_wheels/*.whl - - name: Test wheel 3.8 - if: matrix.arch == 'amd64' - run: | - pushd adbc - - /Library/Frameworks/Python.framework/Versions/3.8/bin/python3.8 -m venv test-env-38 - source test-env-38/bin/activate - export PYTHON_VERSION=3.8 - ./ci/scripts/python_wheel_unix_test.sh $(pwd) - deactivate - - name: Test wheel 3.9 if: matrix.arch == 'amd64' run: | @@ -822,7 +812,7 @@ jobs: strategy: fail-fast: false matrix: - python_version: ["3.9", "3.10", "3.11", "3.12"] + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] env: PYTHON_VERSION: "${{ matrix.python_version }}" # Where to install vcpkg @@ -905,6 +895,7 @@ jobs: name: python${{ matrix.python_version }}-windows retention-days: 7 path: | + adbc/python/adbc_driver_bigquery/repaired_wheels/*.whl adbc/python/adbc_driver_flightsql/repaired_wheels/*.whl adbc/python/adbc_driver_manager/repaired_wheels/*.whl adbc/python/adbc_driver_postgresql/repaired_wheels/*.whl @@ -964,6 +955,7 @@ jobs: name: python-sdist retention-days: 7 path: | + adbc/python/adbc_driver_bigquery/dist/*.tar.gz adbc/python/adbc_driver_flightsql/dist/*.tar.gz adbc/python/adbc_driver_manager/dist/*.tar.gz adbc/python/adbc_driver_postgresql/dist/*.tar.gz @@ -1056,7 +1048,6 @@ jobs: path: conda-packages - uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest use-only-tar-bz2: false use-mamba: true diff --git a/.github/workflows/r-basic.yml b/.github/workflows/r-basic.yml index e698b34bb4..5320e794e6 100644 --- a/.github/workflows/r-basic.yml +++ b/.github/workflows/r-basic.yml @@ -23,7 +23,6 @@ on: branches: - main paths: - - "adbc.h" - "c/**" - "go/adbc/driver/**" - "go/adbc/pkg/**" diff --git a/.github/workflows/r-check.yml b/.github/workflows/r-check.yml index 0b368ca330..cc4cd9b7c6 100644 --- a/.github/workflows/r-check.yml +++ 
b/.github/workflows/r-check.yml @@ -58,9 +58,23 @@ jobs: PKG_CONFIG_PATH="${PKG_CONFIG_PATH}:$(brew --prefix libpq)/lib/pkgconfig:$(brew --prefix openssl)/lib/pkgconfig" echo "PKG_CONFIG_PATH=${PKG_CONFIG_PATH}" >> $GITHUB_ENV - - uses: r-lib/actions/setup-r-dependencies@v2 + # Usually, pak::pkg_install() will run bootstrap.R if it is included and is declared; + # however, this doesn't work for local:: for some reason (which is what + # setup-r-dependencies uses under the hood) + - name: Bootstrap R Package + run: | + pushd r/adbcdrivermanager + R -e 'if (!requireNamespace("nanoarrow", quietly = TRUE)) install.packages("nanoarrow", repos = "https://cloud.r-project.org/")' + R CMD INSTALL . --preclean + popd + pushd "r/${{ inputs.pkg }}" + Rscript bootstrap.R + popd + shell: bash + + - uses: r-lib/actions/setup-r-dependencies@f4937e0dc26f9b99c969cd3e4ca943b576e7f991 with: - extra-packages: any::rcmdcheck, local::../adbcdrivermanager + extra-packages: any::rcmdcheck needs: check working-directory: r/${{ inputs.pkg }} diff --git a/.github/workflows/r-extended.yml b/.github/workflows/r-extended.yml index ccfd898ac8..1ca2e821c2 100644 --- a/.github/workflows/r-extended.yml +++ b/.github/workflows/r-extended.yml @@ -44,7 +44,7 @@ jobs: matrix: rversion: [oldrel, release, devel] os: [macOS, windows, ubuntu] - pkg: [adbcdrivermanager, adbcsqlite, adbcpostgresql, adbcflightsql, adbcsnowflake] + pkg: [adbcdrivermanager, adbcsqlite, adbcpostgresql, adbcflightsql, adbcsnowflake, adbcbigquery] fail-fast: false uses: ./.github/workflows/r-check.yml @@ -61,7 +61,7 @@ jobs: matrix: rversion: ["3.6", "4.0", "4.1"] os: [ubuntu] - pkg: [adbcdrivermanager, adbcsqlite, adbcpostgresql, adbcflightsql, adbcsnowflake] + pkg: [adbcdrivermanager, adbcsqlite, adbcpostgresql, adbcflightsql, adbcsnowflake, adbcbigquery] fail-fast: false uses: ./.github/workflows/r-check.yml @@ -117,9 +117,22 @@ jobs: run: | sudo apt-get install -y valgrind - - uses: r-lib/actions/setup-r-dependencies@v2 + # Usually, pak::pkg_install() will run bootstrap.R if it is included and is declared; + # however, this doesn't work for local:: for some reason (which is what + # setup-r-dependencies uses under the hood) + - name: Bootstrap R Package + run: | + pushd r/adbcdrivermanager + R -e 'if (!requireNamespace("nanoarrow", quietly = TRUE)) install.packages("nanoarrow", repos = "https://cloud.r-project.org/")' + R CMD INSTALL . --preclean + popd + pushd "r/${{ matrix.pkg }}" + Rscript bootstrap.R + popd + shell: bash + + - uses: r-lib/actions/setup-r-dependencies@f4937e0dc26f9b99c969cd3e4ca943b576e7f991 with: - extra-packages: local::../adbcdrivermanager working-directory: r/${{ matrix.pkg }} - name: Start postgres test database diff --git a/.github/workflows/r-standard.yml b/.github/workflows/r-standard.yml index 09cd1f1007..907d779844 100644 --- a/.github/workflows/r-standard.yml +++ b/.github/workflows/r-standard.yml @@ -17,8 +17,7 @@ name: R (standard) -# Runs on PRs that touch the R packages and when pushing files the R -# package uses to main. 
+# Runs on PRs that touch the R packages directly on: pull_request: branches: @@ -32,7 +31,6 @@ on: branches: - main paths: - - "adbc.h" - "c/**" - "go/adbc/driver/**" - "go/adbc/pkg/**" @@ -52,7 +50,7 @@ jobs: strategy: matrix: os: [ubuntu, macOS, windows] - pkg: [adbcdrivermanager, adbcsqlite, adbcpostgresql, adbcflightsql, adbcsnowflake] + pkg: [adbcdrivermanager, adbcsqlite, adbcpostgresql, adbcflightsql, adbcsnowflake, adbcbigquery] uses: ./.github/workflows/r-check.yml with: diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8f2ff7ddf0..ab0554fe2b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -44,6 +44,9 @@ defaults: jobs: native-unix: uses: ./.github/workflows/native-unix.yml + permissions: + contents: read + id-token: write rust: needs: [native-unix] @@ -67,10 +70,37 @@ jobs: run: | rustup toolchain install stable --no-self-update rustup default stable + - name: Get required Go version + run: | + (. ../.env && echo "GO_VERSION=${GO}") >> $GITHUB_ENV + - uses: actions/setup-go@v5 + with: + go-version: "${{ env.GO_VERSION }}" + check-latest: true + cache: true + cache-dependency-path: go/adbc/go.sum + - name: Install Protoc + if: runner.os == 'Linux' + run: | + curl -L "https://github.com/protocolbuffers/protobuf/releases/download/v28.3/protoc-28.3-linux-$(uname -m).zip" -o protoc.zip + unzip protoc.zip -d $HOME/.local + echo "$HOME/.local/bin" >> "$GITHUB_PATH" + - name: Install Protoc + if: runner.os == 'macOS' + run: | + curl -L "https://github.com/protocolbuffers/protobuf/releases/download/v28.3/protoc-28.3-osx-universal_binary.zip" -o protoc.zip + unzip "protoc.zip" -d $HOME/.local + echo "$HOME/.local/bin" >> "$GITHUB_PATH" - uses: actions/download-artifact@v4 with: name: driver-manager-${{ matrix.os }} - path: ${{ github.workspace }}/build + path: "~" + - name: Untar artifacts + shell: bash -l {0} + run: | + cd + mkdir -p ${{ github.workspace }}/build + tar xvf ~/local.tgz -C ${{ github.workspace }}/build --strip-components=1 - name: Set dynamic linker path if: matrix.os == 'ubuntu-latest' run: | @@ -83,10 +113,16 @@ jobs: if: matrix.os == 'macos-13' run: | echo "DYLD_LIBRARY_PATH=/usr/local/opt/sqlite/lib:${{ github.workspace }}/build/lib:$DYLD_LIBRARY_PATH" >> "$GITHUB_ENV" + - name: Set search dir for Snowflake Go lib + run: echo "ADBC_SNOWFLAKE_GO_LIB_DIR=${{ github.workspace }}/build/lib" >> "$GITHUB_ENV" - name: Clippy run: cargo clippy --workspace --all-targets --all-features -- -Dwarnings - name: Test run: cargo test --workspace --all-targets --all-features + # env: + # ADBC_SNOWFLAKE_TESTS: 1 + # ADBC_SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} + # ADBC_SNOWFLAKE_SQL_DB: ADBC_TESTING - name: Doctests run: cargo test --workspace --doc --all-features - name: Check docs diff --git a/.github/workflows/verify.yml b/.github/workflows/verify.yml index 68562ded20..bb546eb913 100644 --- a/.github/workflows/verify.yml +++ b/.github/workflows/verify.yml @@ -30,13 +30,10 @@ on: required: false type: string default: "" - pull_request: - branches: - - main - paths: - - '.github/workflows/verify.yml' - - 'dev/release/verify-release-candidate.sh' - - 'dev/release/verify-release-candidate.ps1' + +# Don't automatically run on pull requests. While we're only using a +# read-only token below, let's play it safe since we are running code out of +# the given branch. 
permissions: contents: read @@ -70,6 +67,8 @@ jobs: TEST_BINARIES: "1" USE_CONDA: "1" VERBOSE: "1" + # Make this available to download_rc_binaries.py + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | ./dev/release/verify-release-candidate.sh ${{ inputs.version }} ${{ inputs.rc }} @@ -89,7 +88,6 @@ jobs: # The Unix script will set up conda itself if: matrix.os == 'windows-latest' with: - miniforge-variant: Mambaforge miniforge-version: latest use-mamba: true - name: Work around ASAN issue (GH-1617) @@ -107,6 +105,9 @@ jobs: USE_CONDA: "1" VERBOSE: "1" run: | + # Rust uses a lot of disk space, free up some space + # https://github.com/actions/runner-images/issues/2840 + sudo rm -rf "$AGENT_TOOLSDIRECTORY" ./dev/release/verify-release-candidate.sh ${{ inputs.version }} ${{ inputs.rc }} - name: Verify if: matrix.os == 'windows-latest' diff --git a/.gitignore b/.gitignore index 0c70d830db..8716f5d7b3 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,7 @@ cpp/.idea/ c/apidoc/html/ c/apidoc/latex/ c/apidoc/xml/ +c/apidoc/objects.inv docs/example.gz docs/example1.dat docs/example3.dat @@ -121,3 +122,7 @@ target/ /ci/linux-packages/yum/merged/ /ci/linux-packages/yum/repositories/ /ci/linux-packages/yum/tmp/ + +# Meson subproject support +/c/subprojects/* +!/c/subprojects/*.wrap diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fafc20df40..6d8bf873cb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: - id: trailing-whitespace exclude: "^r/.*?/_snaps/.*?.md$" - repo: https://github.com/pre-commit/mirrors-clang-format - rev: "v18.1.5" + rev: "v18.1.7" hooks: - id: clang-format types_or: [c, c++] @@ -59,11 +59,11 @@ repos: - "--linelength=90" - "--verbose=2" - repo: https://github.com/golangci/golangci-lint - rev: v1.58.2 + rev: v1.61.0 hooks: - id: golangci-lint entry: bash -c 'cd go/adbc && golangci-lint run --fix --timeout 5m' - types_or: [go] + types_or: [go, go-mod] - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks rev: v2.13.0 hooks: @@ -77,7 +77,7 @@ repos: - id: black types_or: [pyi, python] - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 + rev: 7.1.0 hooks: - id: flake8 types_or: [python] @@ -86,6 +86,10 @@ repos: hooks: - id: isort types_or: [python] + - repo: https://github.com/MarcoGorelli/cython-lint + rev: v0.16.2 + hooks: + - id: cython-lint - repo: https://github.com/vala-lang/vala-lint rev: 8ae2bb65fe66458263d94711ae4ddd978faece00 hooks: @@ -98,10 +102,10 @@ repos: pass_filenames: false entry: "./ci/scripts/run_rat_local.sh" - id: check-cgo-adbc-header - name: Ensure CGO adbc.h is sync'd + name: Ensure CGO adbc.h is syncd language: script pass_filenames: true - files: '(c/driver_manager/adbc_driver_manager\.)|(^adbc\.h)' + files: '^c/include/arrow-adbc/.*\.h$' entry: "./ci/scripts/run_cgo_drivermgr_check.sh" - repo: https://github.com/doublify/pre-commit-rust rev: v1.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index 660a192bde..eb3a26b093 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -605,3 +605,157 @@ - **rust**: add public abstract API and dummy driver implementation (#1725) - **csharp/src/Drivers**: introduce drivers for Apache systems built on Thrift (#1710) - **format**: add info codes for supported capabilities (#1649) + +## ADBC Libraries 13 (2024-07-01) + +### Versions + +- C/C++/GLib/Go/Python/Ruby: 1.1.0 +- C#: 0.13.0 +- Java: 0.13.0 +- R: 0.13.0 +- Rust: 0.13.0 + +### Fix + +- **c/driver/sqlite**: Make SQLite driver C99 compliant (#1946) +- **go/adbc/pkg**: clean up potential sites 
where Go GC may be exposed (#1942) +- **c/driver/postgresql**: chunk large COPY payloads (#1937) +- **go/adbc/pkg**: guard against potential crash (#1938) +- **csharp/src/Drivers/Interop/Snowflake**: Swapping PreBuildEvent to DispatchToInnerBuilds (#1909) +- **csharp/src/Drivers/Apache/Spark**: fix parameter naming convention (#1895) +- **csharp/src/Apache.Arrow.Adbc/C**: GetObjects should preserve a null tableTypes parameter value (#1894) +- **go/adbc/driver/snowflake**: Records dropped on ingestion when empty batch is present (#1866) +- **csharp**: Fix packing process (#1862) +- **csharp/src/Drivers/Apache**: set the precision and scale correctly on Decimal128Type (#1858) + +### Feat + +- Meson Support for ADBC (#1904) +- **rust**: add integration tests and some improvements (#1883) +- **csharp/src/Drivers/Apache/Spark**: extend SQL type name parsing for all types (#1911) +- **csharp/src/Drivers/Apache**: improve type name handling for CHAR/VARCHAR/DECIMAL (#1896) +- **csharp/src/Drivers/Apache**: improve GetObjects metadata returned for columns (#1884) +- **rust**: add the driver manager (#1803) +- **csharp**: redefine C# APIs to prioritize full async support (#1865) +- **csharp/src/Drivers/Apache**: extend capability of GetInfo for Spark driver (#1863) +- **csharp/src/Drivers/Apache**: add implementation for AdbcStatement.SetOption on Spark driver (#1849) +- **csharp**: Move more options to be set centrally and enable TreatWarningsAsErrors (#1852) +- **csharp**: Initial changes for ADBC 1.1 in C# implementation (#1821) +- **csharp/src/Drivers/Apache/Spark**: implement async overrides for Spark driver (#1830) + +## ADBC Libraries 14 (2024-08-30) + +### Versions + +- C/C++/GLib/Go/Python/Ruby: 1.2.0 +- C#: 0.14.0 +- Java: 0.14.0 +- R: 0.14.0 +- Rust: 0.14.0 + +### Feat + +- **go/adbc/driver/snowflake**: Keep track of all files copied and skip empty files in bulk_ingestion (#2106) +- **dev/release**: add Rust release process (#2107) +- **go/adbc/driver/bigquery**: Implement GetObjects and get tests passing (#2044) +- **csharp/src/Client**: add support for parameterized execution (#2096) +- **c/driver/postgresql**: Support queries that bind parameters and return a result (#2065) +- **c/driver/postgresql**: Support JSON and JSONB types (#2072) +- **go/adbc/driver/bigquery**: add schema to reader for BigQuery (#2050) +- **c/driver/postgresql**: Implement consuming a PGresult via the copy reader (#2029) +- **csharp/src/Drivers/BigQuery**: add support for configurable query timeouts (#2043) +- **go/adbc/driver/snowflake**: use vectorized scanner for bulk ingest (#2025) +- **c**: Add BigQuery library to Meson build system (#1994) +- **c**: Add pkgconfig support to Meson build system (#1992) +- **c/driver/postgresql**: FIXED_SIZED_LIST Writer support (#1975) +- **go/adbc/driver**: add support for Google BigQuery (#1722) +- **c/driver/postgresql**: Implement LIST/LARGE_LIST Writer (#1962) +- **c/driver/postgresql**: Read/write support for TIME64[us] (#1960) +- **c/driver/postgresql**: UInt(8/16/32) Writer (#1961) + +### Refactor + +- **c/driver/framework**: Separate C/C++ conversions and error handling into minimal "base" framework (#2090) +- **c/driver/framework**: Remove fmt as required dependency of the driver framework (#2081) +- **c**: Updated include/install paths for adbc.h (#1965) +- **c/driver/postgresql**: Factory func for CopyWriter construction (#1998) +- **c**: Check MakeArray/Batch Error codes with macro (#1959) + +### Fix + +- **go/adbc/driver/snowflake**: Bump gosnowflake to fix 
context error (#2091) +- **c/driver/postgresql**: Fix ingest of streams with zero arrays (#2073) +- **csharp/src/Drivers/BigQuery**: update BigQuery test cases (#2048) +- **ci**: Pin r-lib actions as a workaround for latest action updates (#2051) +- **csharp/src/Drivers/BigQuery**: update BigQuery documents (#2047) +- **go/adbc/driver/snowflake**: split files properly after reaching targetSize on ingestion (#2026) +- **c/driver/postgresql**: Ensure schema ordering is consisent and respects case sensitivity of table names (#2028) +- **docs**: update broken link (#2016) +- **docs**: correct snowflake options for bulk ingest (#2004) +- **go/adbc/driver/flightsql**: propagate headers in GetObjects (#1996) +- **c/driver/postgresql**: Fix compiler warning on gcc14 (#1990) +- **r/adbcdrivermanager**: Ensure that class of object is checked before calling R_ExternalPtrAddrFn (#1989) +- **ci**: update website_build.sh for new versioning scheme (#1972) +- **dev/release**: update C# tag (#1973) +- **c/vendor/nanoarrow**: Fix -Wreorder warning (#1966) + +## ADBC Libraries 15 (2024-11-08) + +### Versions + +- C/C++/GLib/Go/Python/Ruby: 1.3.0 +- C#: 0.15.0 +- Java: 0.15.0 +- R: 0.15.0 +- Rust: 0.15.0 + +### Feat + +- **c/driver/postgresql**: Enable basic connect/query workflow for Redshift (#2219) +- **rust/drivers/datafusion**: add support for bulk ingest (#2279) +- **csharp/src/Drivers/Apache**: convert Double to Float for Apache Spark on scalar conversion (#2296) +- **go/adbc/driver/snowflake**: update to the latest 1.12.0 gosnowflake driver (#2298) +- **csharp/src/Drivers/BigQuery**: support max stream count setting when creating read session (#2289) +- **rust/drivers**: adbc driver for datafusion (#2267) +- **go/adbc/driver/snowflake**: improve GetObjects performance and semantics (#2254) +- **c**: Implement ingestion and testing for float16, string_view, and binary_view (#2234) +- **r**: Add R BigQuery driver wrapper (#2235) +- **csharp/src/Drivers/Apache/Spark**: add request_timeout_ms option to allow longer HTTP request length (#2218) +- **go/adbc/driver/snowflake**: add support for a client config file (#2197) +- **csharp/src/Client**: Additional parameter support for DbCommand (#2195) +- **csharp/src/Drivers/Apache/Spark**: add option to ignore TLS/SSL certificate exceptions (#2188) +- **csharp/src/Drivers/Apache/Spark**: Perform scalar data type conversion for Spark over HTTP (#2152) +- **csharp/src/Drivers/Apache/Spark**: Azure HDInsight Spark Documentation (#2164) +- **c/driver/postgresql**: Implement ingestion of list types for PostgreSQL (#2153) +- **csharp/src/Drivers/Apache/Spark**: poc - Support for Apache Spark over HTTP (non-Arrow) (#2018) +- **c/driver/postgresql**: add `arrow.opaque` type metadata (#2122) + +### Fix + +- **csharp/src/Drivers/Apache**: fix float data type handling for tests on Databricks Spark (#2283) +- **go/adbc/driver/internal/driverbase**: proper unmarshalling for ConstraintColumnNames (#2285) +- **csharp/src/Drivers/Apache**: fix to workaround concurrency issue (#2282) +- **csharp/src/Drivers/Apache**: correctly handle empty response and add Client tests (#2275) +- **csharp/src/Drivers/Apache**: remove interleaved async look-ahead code (#2273) +- **c/driver_manager**: More robust error reporting for errors that occur before AdbcDatabaseInit() (#2266) +- **rust**: implement database/connection constructors without options (#2242) +- **csharp/src/Drivers**: update System.Text.Json to version 8.0.5 because of known vulnerability (#2238) +- 
**csharp/src/Drivers/Apache/Spark**: correct batch handling for the HiveServer2Reader (#2215) +- **go/adbc/driver/snowflake**: call GetObjects with null catalog at catalog depth (#2194) +- **csharp/src/Drivers/Apache/Spark**: correct BatchSize implementation for base reader (#2199) +- **csharp/src/Drivers/Apache/Spark**: correct precision/scale handling with zeros in fractional portion (#2198) +- **csharp/src/Drivers/BigQuery**: Fixed GBQ driver issue when results.TableReference is null (#2165) +- **go/adbc/driver/snowflake**: fix setting database and schema context after initial connection (#2169) +- **csharp/src/Drivers/Interop/Snowflake**: add test to demonstrate DEFAULT_ROLE behavior (#2151) +- **c/driver/postgresql**: Improve error reporting for queries that error before the COPY header is sent (#2134) + +### Refactor + +- **c/driver/postgresql**: cleanups for result_helper signatures (#2261) +- **c/driver/postgresql**: Use GetObjectsHelper from framework to build objects (#2189) +- **csharp/src/Drivers/Apache/Spark**: use UTF8 string for data conversion, instead of .NET String (#2192) +- **c/driver/postgresql**: Use Status for error handling in BindStream (#2187) +- **c/driver/postgresql**: Use Status instead of AdbcStatusCode/AdbcError in result helper (#2178) +- **c/driver**: Use non-objects framework components in Postgres driver (#2166) +- **c/driver/postgresql**: Use copy writer in BindStream for parameter binding (#2157) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7acb111a94..c9cbd0afb0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,8 +31,8 @@ https://github.com/apache/arrow-adbc/issues Some dependencies are required to build and test the various ADBC packages. For C/C++, you will most likely want a [Conda][conda] installation, -with [Mambaforge][mambaforge] being the most convenient distribution. -If you have Mambaforge installed, you can set up a development +with [Miniforge][miniforge] being the most convenient distribution. +If you have Miniforge installed, you can set up a development environment as follows: ```shell @@ -52,7 +52,7 @@ CMake or other build tool appropriately. However, we primarily develop and support Conda users. [conda]: https://docs.conda.io/en/latest/ -[mambaforge]: https://mamba.readthedocs.io/en/latest/installation.html +[miniforge]: https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html ### Running Integration Tests @@ -148,6 +148,43 @@ for details. [cmake-prefix-path]: https://cmake.org/cmake/help/latest/variable/CMAKE_PREFIX_PATH.html [gtest]: https://github.com/google/googletest/ +### C/C++ with Meson + +While CMake is the officially supported build generator, there is limited, +experimental support for the Meson build system. Meson offers arguably better +dependency management than CMake, with a syntax that Python developers may +find more readable. + +To use Meson, start at the c directory and run: + +```shell +$ meson setup build +``` + +For a full list of options, ``meson configure`` will bring up a pager +with sections that you can navigate. The "Project Options" section in particular +will show you what ADBC has to offer, and each option can be provided using +the form ``-D_option_=_value_``. For example, to build a debug version of +the SQLite3 driver along with tests, you would run: + +```shell +$ meson configure -Dbuildtype=debug -Dsqlite=true -Dtests=true build +``` + +With the options set, you can then compile the project.
For most dependencies, +Meson will try to find them on your system and fall back to downloading a copy +from its WrapDB for you: + +```shell +$ meson compile -C build +``` + +To run the test suite, simply run: + +```shell +$ meson test -C build +``` + ### C#/.NET Make sure [.NET Core is installed](https://dotnet.microsoft.com/en-us/download). diff --git a/README.md b/README.md index 2cf24ecc9d..cb2b9c0699 100644 --- a/README.md +++ b/README.md @@ -57,4 +57,4 @@ User documentation can be found at https://arrow.apache.org/adbc ## Development and Contributing -For detailed instructions on how to build the various ADBC libraries, see CONTRIBUTING.md. +For detailed instructions on how to build the various ADBC libraries, see [CONTRIBUTING.md](CONTRIBUTING.md). diff --git a/c/CMakeLists.txt b/c/CMakeLists.txt index 06be814267..be69103d06 100644 --- a/c/CMakeLists.txt +++ b/c/CMakeLists.txt @@ -23,6 +23,7 @@ include(BuildUtils) project(adbc VERSION "${ADBC_BASE_VERSION}" LANGUAGES C CXX) +set(CMAKE_C_STANDARD 99) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -53,6 +54,9 @@ if(ADBC_DRIVER_FLIGHTSQL) endif() if(ADBC_DRIVER_MANAGER) + install(FILES "${REPOSITORY_ROOT}/c/include/adbc_driver_manager.h" DESTINATION include) + install(FILES "${REPOSITORY_ROOT}/c/include/arrow-adbc/adbc_driver_manager.h" + DESTINATION include/arrow-adbc) add_subdirectory(driver_manager) endif() @@ -68,6 +72,10 @@ if(ADBC_DRIVER_SNOWFLAKE) add_subdirectory(driver/snowflake) endif() +if(ADBC_DRIVER_BIGQUERY) + add_subdirectory(driver/bigquery) +endif() + if(ADBC_INTEGRATION_DUCKDB) add_subdirectory(integration/duckdb) endif() @@ -118,6 +126,10 @@ LIBRARY=$" ${Python3_EXECUTABLE} -m pi if(ADBC_DRIVER_SNOWFLAKE) adbc_install_python_package(snowflake) endif() + + if(ADBC_DRIVER_BIGQUERY) + adbc_install_python_package(bigquery) + endif() endif() validate_config() diff --git a/c/apidoc/Doxyfile b/c/apidoc/Doxyfile index df72a8bd03..2f4924281c 100644 --- a/c/apidoc/Doxyfile +++ b/c/apidoc/Doxyfile @@ -500,7 +500,7 @@ EXTRACT_ALL = NO # be included in the documentation. # The default value is: NO. -EXTRACT_PRIVATE = NO +EXTRACT_PRIVATE = YES # If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual # methods of a class will be included in the documentation. @@ -891,7 +891,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../../adbc.h ../../README.md ../../c/driver_manager/adbc_driver_manager.h +INPUT = ../../c/include/arrow-adbc/adbc.h ../../README.md ../../c/include/arrow-adbc/adbc_driver_manager.h ../../c/driver/framework/ # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses @@ -920,12 +920,7 @@ INPUT_ENCODING = UTF-8 # comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, # *.vhdl, *.ucf, *.qsf and *.ice. 
-FILE_PATTERNS = *.c \ - *.cc \ - *.cxx \ - *.cpp \ - *.c++ \ - *.java \ +FILE_PATTERNS = *.java \ *.ii \ *.ixx \ *.ipp \ @@ -1007,7 +1002,7 @@ EXCLUDE_PATTERNS = # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories use the pattern */test/* -EXCLUDE_SYMBOLS = +EXCLUDE_SYMBOLS = ADBC ADBC_DRIVER_MANAGER_H # The EXAMPLE_PATH tag can be used to specify one or more files or directories # that contain example code fragments that are included (see the \include diff --git a/c/cmake_modules/AdbcDefines.cmake b/c/cmake_modules/AdbcDefines.cmake index 6c83cca54c..d27cb8fb3d 100644 --- a/c/cmake_modules/AdbcDefines.cmake +++ b/c/cmake_modules/AdbcDefines.cmake @@ -93,7 +93,9 @@ if(MSVC) # Don't warn about padding added after members add_compile_options(/wd4820) add_compile_options(/wd5027) + add_compile_options(/wd5039) add_compile_options(/wd5045) + add_compile_options(/wd5246) elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") diff --git a/c/cmake_modules/AdbcVersion.cmake b/c/cmake_modules/AdbcVersion.cmake index 68e8cc830c..fcb67c08c5 100644 --- a/c/cmake_modules/AdbcVersion.cmake +++ b/c/cmake_modules/AdbcVersion.cmake @@ -21,7 +21,7 @@ # ------------------------------------------------------------ # Version definitions -set(ADBC_VERSION "1.1.0-SNAPSHOT") +set(ADBC_VERSION "1.4.0-SNAPSHOT") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ADBC_BASE_VERSION "${ADBC_VERSION}") string(REPLACE "." ";" _adbc_version_list "${ADBC_BASE_VERSION}") list(GET _adbc_version_list 0 ADBC_VERSION_MAJOR) diff --git a/c/cmake_modules/DefineOptions.cmake b/c/cmake_modules/DefineOptions.cmake index b6dd1079d2..13e6757347 100644 --- a/c/cmake_modules/DefineOptions.cmake +++ b/c/cmake_modules/DefineOptions.cmake @@ -236,6 +236,7 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ADBC_DRIVER_POSTGRESQL "Build the PostgreSQL driver" OFF) define_option(ADBC_DRIVER_SQLITE "Build the SQLite driver" OFF) define_option(ADBC_DRIVER_SNOWFLAKE "Build the Snowflake driver" OFF) + define_option(ADBC_DRIVER_BIGQUERY "Build the BigQuery driver" OFF) define_option(ADBC_INTEGRATION_DUCKDB "Build the test suite for DuckDB" OFF) endif() diff --git a/c/driver/bigquery/CMakeLists.txt b/c/driver/bigquery/CMakeLists.txt new file mode 100644 index 0000000000..fe3937878a --- /dev/null +++ b/c/driver/bigquery/CMakeLists.txt @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +include(GoUtils) + set(LDFLAGS "$<$<CONFIG:Release>:-s> $<$<CONFIG:Release>:-w>") +add_go_lib("${REPOSITORY_ROOT}/go/adbc/pkg/bigquery/" + adbc_driver_bigquery + SOURCES + driver.go + utils.h + utils.c + BUILD_TAGS + driverlib + PKG_CONFIG_NAME + adbc-driver-bigquery + SHARED_LINK_FLAGS + ${LDFLAGS} + OUTPUTS + ADBC_LIBRARIES) + +foreach(LIB_TARGET ${ADBC_LIBRARIES}) + target_include_directories(${LIB_TARGET} SYSTEM + INTERFACE ${REPOSITORY_ROOT} ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/vendor + ${REPOSITORY_ROOT}/c/driver) +endforeach() + +if(ADBC_TEST_LINKAGE STREQUAL "shared") + set(TEST_LINK_LIBS adbc_driver_bigquery_shared) +else() + set(TEST_LINK_LIBS adbc_driver_bigquery_static) +endif() + +if(ADBC_BUILD_TESTS) + add_test_case(driver_bigquery_test + PREFIX + adbc + EXTRA_LABELS + driver-bigquery + SOURCES + bigquery_test.cc + EXTRA_LINK_LIBS + adbc_driver_common + adbc_validation + nanoarrow + ${TEST_LINK_LIBS}) + target_compile_features(adbc-driver-bigquery-test PRIVATE cxx_std_17) + target_include_directories(adbc-driver-bigquery-test SYSTEM + PRIVATE ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ + ${REPOSITORY_ROOT}/c/vendor + ${REPOSITORY_ROOT}/c/driver + ${REPOSITORY_ROOT}/c/driver/common) + adbc_configure_target(adbc-driver-bigquery-test) +endif() diff --git a/c/driver/bigquery/README.md b/c/driver/bigquery/README.md new file mode 100644 index 0000000000..6784787367 --- /dev/null +++ b/c/driver/bigquery/README.md @@ -0,0 +1,65 @@ + + +# ADBC BigQuery Driver + +This driver provides an interface to +[BigQuery](https://cloud.google.com/bigquery) using ADBC. + +## Building + +See [CONTRIBUTING.md](../../CONTRIBUTING.md) for details. + +## Testing + +BigQuery credentials and project ID are required. + +### Environment Variables +#### Project ID +Set `BIGQUERY_PROJECT_ID` to the project ID. + +#### Authentication +Set either one of the following environment variables for authentication: + +##### BIGQUERY_JSON_CREDENTIAL_FILE +Path to the JSON credential file. This file can be generated using `gcloud`: + +```sh +gcloud auth application-default login +``` + +By default, the generated JSON credential file is located at + +```sh +$HOME/.config/gcloud/application_default_credentials.json +``` + +##### BIGQUERY_JSON_CREDENTIAL_STRING +Store the whole JSON credential content, something like + +```json +{ + "account": "", + "client_id": "123456789012-1234567890abcdefabcdefabcdefabcd.apps.googleusercontent.com", + "client_secret": "d-SECRETSECRETSECRETSECR", + "refresh_token": "1//1234567890abcdefabcdefabcdef-abcdefabcd-abcdefabcdefabcdefabcdefab-abcdefabcdefabcdefabcdefabcdef-ab", + "type": "authorized_user", + "universe_domain": "googleapis.com" +} +``` diff --git a/c/driver/bigquery/adbc-driver-bigquery.pc.in b/c/driver/bigquery/adbc-driver-bigquery.pc.in new file mode 100644 index 0000000000..dfbd790964 --- /dev/null +++ b/c/driver/bigquery/adbc-driver-bigquery.pc.in @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +prefix=@CMAKE_INSTALL_PREFIX@ +libdir=@ADBC_PKG_CONFIG_LIBDIR@ + +Name: Apache Arrow Database Connectivity (ADBC) BigQuery driver +Description: The ADBC BigQuery driver provides an ADBC driver for BigQuery. +URL: https://github.com/apache/arrow-adbc +Version: @ADBC_VERSION@ +Libs: -L${libdir} -ladbc_driver_bigquery diff --git a/c/driver/bigquery/bigquery_test.cc b/c/driver/bigquery/bigquery_test.cc new file mode 100644 index 0000000000..b80f36336a --- /dev/null +++ b/c/driver/bigquery/bigquery_test.cc @@ -0,0 +1,363 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "validation/adbc_validation.h" +#include "validation/adbc_validation_util.h" + +using adbc_validation::IsOkStatus; + +#define CHECK_OK(EXPR) \ + do { \ + if (auto adbc_status = (EXPR); adbc_status != ADBC_STATUS_OK) { \ + return adbc_status; \ + } \ + } while (false) + +namespace { +std::string GetUuid() { + static std::random_device dev; + static std::mt19937 rng(dev()); + + std::uniform_int_distribution dist(0, 15); + + const char* v = "0123456789ABCDEF"; + const bool dash[] = {0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0}; + + std::string res; + for (int i = 0; i < 16; i++) { + if (dash[i]) res += "-"; + res += v[dist(rng)]; + res += v[dist(rng)]; + } + return res; +} +} // namespace + +class BigQueryQuirks : public adbc_validation::DriverQuirks { + public: + BigQueryQuirks() { + auth_value_ = std::getenv("BIGQUERY_JSON_CREDENTIAL_STRING"); + if (auth_value_ == nullptr || std::strlen(auth_value_) == 0) { + auth_value_ = std::getenv("BIGQUERY_JSON_CREDENTIAL_FILE"); + if (auth_value_ == nullptr || std::strlen(auth_value_) == 0) { + skip_ = true; + } else { + auth_type_ = "adbc.bigquery.sql.auth_type.json_credential_file"; + } + } else { + auth_type_ = "adbc.bigquery.sql.auth_type.json_credential_string"; + } + + catalog_name_ = std::getenv("BIGQUERY_PROJECT_ID"); + if (catalog_name_ == nullptr || std::strlen(catalog_name_) == 0) { + skip_ = true; + } + } + + AdbcStatusCode SetupDatabase(struct AdbcDatabase* database, + struct AdbcError* error) const override { + EXPECT_THAT( + AdbcDatabaseSetOption(database, "adbc.bigquery.sql.auth_type", auth_type_, error), + IsOkStatus(error)); + EXPECT_THAT(AdbcDatabaseSetOption(database, 
"adbc.bigquery.sql.auth_credentials", + auth_value_, error), + IsOkStatus(error)); + EXPECT_THAT(AdbcDatabaseSetOption(database, "adbc.bigquery.sql.project_id", + catalog_name_, error), + IsOkStatus(error)); + EXPECT_THAT(AdbcDatabaseSetOption(database, "adbc.bigquery.sql.dataset_id", + schema_.c_str(), error), + IsOkStatus(error)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode DropTable(struct AdbcConnection* connection, const std::string& name, + struct AdbcError* error) const override { + adbc_validation::Handle statement; + CHECK_OK(AdbcStatementNew(connection, &statement.value, error)); + + std::string drop = "DROP TABLE IF EXISTS \""; + drop += name; + drop += "\""; + CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, drop.c_str(), error)); + CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); + + CHECK_OK(AdbcStatementRelease(&statement.value, error)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode CreateSampleTable(struct AdbcConnection* connection, + const std::string& name, + struct AdbcError* error) const override { + adbc_validation::Handle statement; + CHECK_OK(AdbcStatementNew(connection, &statement.value, error)); + + std::string create = "CREATE TABLE `ADBC_TESTING."; + create += name; + create += "` (int64s INT, strings TEXT)"; + CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error)); + CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); + // XXX: is there a better way to wait for BigQuery? (Why does 'CREATE + // TABLE' not wait for commit?) + std::this_thread::sleep_for(std::chrono::seconds(5)); + + std::string insert = "INSERT INTO `ADBC_TESTING."; + insert += name; + insert += "` VALUES (42, 'foo'), (-42, NULL), (NULL, '')"; + CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, insert.c_str(), error)); + CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); + + CHECK_OK(AdbcStatementRelease(&statement.value, error)); + return ADBC_STATUS_OK; + } + + ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override { + switch (ingest_type) { + case NANOARROW_TYPE_INT8: + case NANOARROW_TYPE_UINT8: + case NANOARROW_TYPE_INT16: + case NANOARROW_TYPE_UINT16: + case NANOARROW_TYPE_INT32: + case NANOARROW_TYPE_UINT32: + case NANOARROW_TYPE_INT64: + case NANOARROW_TYPE_UINT64: + return NANOARROW_TYPE_INT64; + case NANOARROW_TYPE_FLOAT: + case NANOARROW_TYPE_DOUBLE: + return NANOARROW_TYPE_DOUBLE; + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + return NANOARROW_TYPE_STRING; + default: + return ingest_type; + } + } + + std::string BindParameter(int index) const override { return "?"; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return true; } + bool supports_concurrent_statements() const override { return true; } + bool supports_transactions() const override { return false; } + bool supports_get_sql_info() const override { return false; } + bool supports_get_objects() const override { return false; } + bool supports_metadata_current_catalog() const override { return false; } + bool supports_metadata_current_db_schema() const override { return false; } + bool supports_partitioned_data() const override { return false; } + bool supports_dynamic_parameter_binding() const override { return true; } + bool supports_error_on_incompatible_schema() const override { return false; } + bool ddl_implicit_commit_txn() const override { return true; } + std::string db_schema() const override { return schema_; } + + const char* auth_type_; 
+ const char* auth_value_; + const char* catalog_name_; + bool skip_{false}; + std::string schema_; +}; + +class BigQueryTest : public ::testing::Test, public adbc_validation::DatabaseTest { + public: + const adbc_validation::DriverQuirks* quirks() const override { return &quirks_; } + void SetUp() override { + if (quirks_.skip_) { + GTEST_SKIP(); + } + ASSERT_NO_FATAL_FAILURE(SetUpTest()); + } + void TearDown() override { + if (!quirks_.skip_) { + ASSERT_NO_FATAL_FAILURE(TearDownTest()); + } + } + + protected: + BigQueryQuirks quirks_; +}; +ADBCV_TEST_DATABASE(BigQueryTest) + +class BigQueryConnectionTest : public ::testing::Test, + public adbc_validation::ConnectionTest { + public: + const adbc_validation::DriverQuirks* quirks() const override { return &quirks_; } + void SetUp() override { + if (quirks_.skip_) { + GTEST_SKIP(); + } + ASSERT_NO_FATAL_FAILURE(SetUpTest()); + } + void TearDown() override { + if (!quirks_.skip_) { + ASSERT_NO_FATAL_FAILURE(TearDownTest()); + } + } + + // Supported, but we don't validate the values + void TestMetadataCurrentCatalog() { GTEST_SKIP(); } + void TestMetadataCurrentDbSchema() { GTEST_SKIP(); } + + protected: + BigQueryQuirks quirks_; +}; +ADBCV_TEST_CONNECTION(BigQueryConnectionTest) + +class BigQueryStatementTest : public ::testing::Test, + public adbc_validation::StatementTest { + public: + const adbc_validation::DriverQuirks* quirks() const override { return &quirks_; } + + void SetUp() override { + if (quirks_.skip_) { + GTEST_SKIP(); + } + ASSERT_NO_FATAL_FAILURE(SetUpTest()); + } + void TearDown() override { + if (!quirks_.skip_) { + ASSERT_NO_FATAL_FAILURE(TearDownTest()); + } + } + + void TestSqlIngestInterval() { GTEST_SKIP(); } + void TestSqlIngestDuration() { GTEST_SKIP(); } + + void TestSqlIngestColumnEscaping() { GTEST_SKIP(); } + + public: + // will need to be updated to SetUpTestSuite when gtest is upgraded + static void SetUpTestCase() { + if (quirks_.skip_) { + GTEST_SKIP(); + } + + struct AdbcError error; + struct AdbcDatabase db; + struct AdbcConnection connection; + struct AdbcStatement statement; + + std::memset(&error, 0, sizeof(error)); + std::memset(&db, 0, sizeof(db)); + std::memset(&connection, 0, sizeof(connection)); + std::memset(&statement, 0, sizeof(statement)); + + ASSERT_THAT(AdbcDatabaseNew(&db, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks_.SetupDatabase(&db, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcDatabaseInit(&db, &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &db, &error), IsOkStatus(&error)); + + std::string schema_name = "ADBC_TESTING_" + GetUuid(); + std::string query = "CREATE SCHEMA `"; + query += quirks_.catalog_name_; + query += "." 
+ schema_name + "`"; + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + quirks_.schema_ = schema_name; + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcDatabaseRelease(&db, &error), IsOkStatus(&error)); + } + + // will need to be updated to TearDownTestSuite when gtest is upgraded + static void TearDownTestCase() { + if (quirks_.skip_) { + GTEST_SKIP(); + } + + struct AdbcError error; + struct AdbcDatabase db; + struct AdbcConnection connection; + struct AdbcStatement statement; + + std::memset(&error, 0, sizeof(error)); + std::memset(&db, 0, sizeof(db)); + std::memset(&connection, 0, sizeof(connection)); + std::memset(&statement, 0, sizeof(statement)); + + ASSERT_THAT(AdbcDatabaseNew(&db, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks_.SetupDatabase(&db, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcDatabaseInit(&db, &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &db, &error), IsOkStatus(&error)); + + std::string query = "DROP SCHEMA `" + std::string(quirks_.catalog_name_) + "." + + quirks_.schema_ + "`"; + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcDatabaseRelease(&db, &error), IsOkStatus(&error)); + } + + protected: + void ValidateIngestedTemporalData(struct ArrowArrayView* values, ArrowType type, + enum ArrowTimeUnit unit, + const char* timezone) override { + switch (type) { + case NANOARROW_TYPE_TIMESTAMP: { + std::vector> expected; + switch (unit) { + case NANOARROW_TIME_UNIT_SECOND: + expected = {std::nullopt, -42, 0, 42}; + break; + case NANOARROW_TIME_UNIT_MILLI: + expected = {std::nullopt, -42, 0, 42}; + break; + case NANOARROW_TIME_UNIT_MICRO: + expected = {std::nullopt, -42, 0, 42}; + break; + case NANOARROW_TIME_UNIT_NANO: + expected = {std::nullopt, -42, 0, 42}; + break; + } + ASSERT_NO_FATAL_FAILURE( + adbc_validation::CompareArray(values, expected)); + break; + } + default: + FAIL() << "ValidateIngestedTemporalData not implemented for type " << type; + } + } + + static BigQueryQuirks quirks_; +}; + +BigQueryQuirks BigQueryStatementTest::quirks_; +ADBCV_TEST_STATEMENT(BigQueryStatementTest) diff --git a/c/driver/bigquery/meson.build b/c/driver/bigquery/meson.build new file mode 100644 index 0000000000..7eee6bf784 --- /dev/null +++ b/c/driver/bigquery/meson.build @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +golang = find_program('go') + +if build_machine.system() == 'windows' + prefix = '' + suffix = '.lib' +elif build_machine.system() == 'darwin' + prefix = 'lib' + suffix = '.dylib' +else + prefix = 'lib' + suffix = '.so' +endif + +adbc_driver_bigquery_name = prefix + 'adbc_driver_bigquery' + suffix +adbc_driver_bigquery_lib = custom_target( + 'adbc_driver_bigquery', + output: adbc_driver_bigquery_name, + command : [ + golang, + 'build', + '-C', + meson.project_source_root() + '/../go/adbc/pkg/bigquery', + '-tags=driverlib', + '-buildmode=c-shared', + '-o', + meson.current_build_dir() + '/' + adbc_driver_bigquery_name, + ], + install : true, + install_dir : '.', +) + +pkg.generate( + name: 'Apache Arrow Database Connectivity (ADBC) BigQuery driver', + description: 'The ADBC BigQuery driver provides an ADBC driver for BigQuery.', + libraries: [adbc_driver_bigquery_lib], + url: 'https://github.com/apache/arrow-adbc', + filebase: 'adbc-driver-bigquery', +) + +if get_option('tests') + exc = executable( + 'adbc-driver-bigquery-test', + 'bigquery_test.cc', + include_directories: [root_dir, driver_dir], + link_with: [ + adbc_common_lib, + adbc_driver_bigquery_lib + ], + dependencies: [adbc_validation_dep], + ) + test('adbc-driver-bigquery', exc) +endif diff --git a/c/driver/common/CMakeLists.txt b/c/driver/common/CMakeLists.txt index 74d57406a6..751eda3632 100644 --- a/c/driver/common/CMakeLists.txt +++ b/c/driver/common/CMakeLists.txt @@ -18,7 +18,7 @@ add_library(adbc_driver_common STATIC utils.c) adbc_configure_target(adbc_driver_common) set_target_properties(adbc_driver_common PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_include_directories(adbc_driver_common PRIVATE "${REPOSITORY_ROOT}" +target_include_directories(adbc_driver_common PRIVATE "${REPOSITORY_ROOT}/c/include" "${REPOSITORY_ROOT}/c/vendor") if(ADBC_BUILD_TESTS) @@ -29,12 +29,12 @@ if(ADBC_BUILD_TESTS) driver-common SOURCES utils_test.cc - driver_test.cc EXTRA_LINK_LIBS adbc_driver_common nanoarrow) target_compile_features(adbc-driver-common-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-common-test - PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/vendor") + PRIVATE "${REPOSITORY_ROOT}/c/include" + "${REPOSITORY_ROOT}/c/vendor") adbc_configure_target(adbc-driver-common-test) endif() diff --git a/c/driver/common/driver_base.h b/c/driver/common/driver_base.h deleted file mode 100644 index 8f9fb7d074..0000000000 --- a/c/driver/common/driver_base.h +++ /dev/null @@ -1,770 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include -#include -#include -#include -#include - -#include - -// This file defines a developer-friendly way to create an ADBC driver, currently intended -// for testing the R driver manager. It handles errors, option getting/setting, and -// managing the export of the many C callables that compose an AdbcDriver. In general, -// functions or methods intended to be called from C are prefixed with "C" and are private -// (i.e., the public and protected methods are the only ones that driver authors should -// ever interact with). -// -// Example: -// class MyDatabase: public DatabaseObjectBase {}; -// class MyConnection: public ConnectionObjectBase {}; -// class MyStatement: public StatementObjectbase {}; -// AdbcStatusCode VoidDriverInitFunc(int version, void* raw_driver, AdbcError* error) { -// return Driver::Init( -// version, raw_driver, error); -// } - -namespace adbc { - -namespace common { - -class Error { - public: - explicit Error(std::string message) : message_(std::move(message)) { - std::memset(sql_state_, 0, sizeof(sql_state_)); - } - - explicit Error(const char* message) : Error(std::string(message)) {} - - Error(std::string message, std::vector> details) - : message_(std::move(message)), details_(std::move(details)) { - std::memset(sql_state_, 0, sizeof(sql_state_)); - } - - void AddDetail(std::string key, std::string value) { - details_.push_back({std::move(key), std::move(value)}); - } - - void ToAdbc(AdbcError* adbc_error, AdbcDriver* driver = nullptr) { - if (adbc_error == nullptr) { - return; - } - - if (adbc_error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { - auto error_owned_by_adbc_error = - new Error(std::move(message_), std::move(details_)); - adbc_error->message = - const_cast(error_owned_by_adbc_error->message_.c_str()); - adbc_error->private_data = error_owned_by_adbc_error; - adbc_error->private_driver = driver; - } else { - adbc_error->message = reinterpret_cast(std::malloc(message_.size() + 1)); - if (adbc_error->message != nullptr) { - std::memcpy(adbc_error->message, message_.c_str(), message_.size() + 1); - } - } - - std::memcpy(adbc_error->sqlstate, sql_state_, sizeof(sql_state_)); - adbc_error->release = &CRelease; - } - - private: - std::string message_; - std::vector> details_; - char sql_state_[5]; - - // Let the Driver use these to expose C callables wrapping option setters/getters - template - friend class Driver; - - int CDetailCount() const { return details_.size(); } - - AdbcErrorDetail CDetail(int index) const { - const auto& detail = details_[index]; - return {detail.first.c_str(), reinterpret_cast(detail.second.data()), - detail.second.size() + 1}; - } - - static void CRelease(AdbcError* error) { - if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { - auto error_obj = reinterpret_cast(error->private_data); - delete error_obj; - } else { - std::free(error->message); - } - - std::memset(error, 0, sizeof(AdbcError)); - } -}; - -// Variant that handles the option types that can be get/set by databases, -// connections, and statements. 
It currently does not attempt conversion -// (i.e., getting a double option as a string). -class Option { - public: - enum Type { TYPE_MISSING, TYPE_STRING, TYPE_BYTES, TYPE_INT, TYPE_DOUBLE }; - - Option() : type_(TYPE_MISSING) {} - explicit Option(const std::string& value) : type_(TYPE_STRING), value_string_(value) {} - explicit Option(const std::vector& value) - : type_(TYPE_BYTES), value_bytes_(value) {} - explicit Option(double value) : type_(TYPE_DOUBLE), value_double_(value) {} - explicit Option(int64_t value) : type_(TYPE_INT), value_int_(value) {} - - Type type() const { return type_; } - - const std::string& GetStringUnsafe() const { return value_string_; } - - const std::vector& GetBytesUnsafe() const { return value_bytes_; } - - int64_t GetIntUnsafe() const { return value_int_; } - - double GetDoubleUnsafe() const { return value_double_; } - - private: - Type type_; - std::string value_string_; - std::vector value_bytes_; - double value_double_; - int64_t value_int_; - - // Methods used by trampolines to export option values in C below - friend class ObjectBase; - - AdbcStatusCode CGet(char* out, size_t* length) const { - switch (type_) { - case TYPE_STRING: { - const std::string& value = GetStringUnsafe(); - size_t value_size_with_terminator = value.size() + 1; - if (*length < value_size_with_terminator) { - *length = value_size_with_terminator; - } else { - memcpy(out, value.data(), value_size_with_terminator); - } - - return ADBC_STATUS_OK; - } - default: - return ADBC_STATUS_NOT_FOUND; - } - } - - AdbcStatusCode CGet(uint8_t* out, size_t* length) const { - switch (type_) { - case TYPE_BYTES: { - const std::vector& value = GetBytesUnsafe(); - if (*length < value.size()) { - *length = value.size(); - } else { - memcpy(out, value.data(), value.size()); - } - - return ADBC_STATUS_OK; - } - default: - return ADBC_STATUS_NOT_FOUND; - } - } - - AdbcStatusCode CGet(int64_t* value) const { - switch (type_) { - case TYPE_INT: - *value = GetIntUnsafe(); - return ADBC_STATUS_OK; - default: - return ADBC_STATUS_NOT_FOUND; - } - } - - AdbcStatusCode CGet(double* value) const { - switch (type_) { - case TYPE_DOUBLE: - *value = GetDoubleUnsafe(); - return ADBC_STATUS_OK; - default: - return ADBC_STATUS_NOT_FOUND; - } - } -}; - -// Base class for private_data of AdbcDatabase, AdbcConnection, and AdbcStatement -// This class handles option setting and getting. -class ObjectBase { - public: - ObjectBase() : driver_(nullptr) {} - - virtual ~ObjectBase() {} - - // Driver authors can override this method to reject options that are not supported or - // that are set at a time not supported by the driver (e.g., to reject options that are - // set after Init() is called if this is not supported). - virtual AdbcStatusCode SetOption(const std::string& key, const Option& value) { - options_[key] = value; - return ADBC_STATUS_OK; - } - - // Called After zero or more SetOption() calls. The parent is the private_data of - // the AdbcDriver, AdbcDatabase, or AdbcConnection when initializing a subclass of - // DatabaseObjectBase, ConnectionObjectBase, and StatementObjectBase (respectively). - // For example, if you have defined Driver, - // you can reinterpret_cast(parent) in MyConnection::Init(). - virtual AdbcStatusCode Init(void* parent, AdbcError* error) { return ADBC_STATUS_OK; } - - // Called when the corresponding AdbcXXXRelease() function is invoked from C. 
- // Driver authors can override this method to return an error if the object is - // not in a valid state (e.g., if a connection has open statements) or to clean - // up resources when resource cleanup could fail. Resource cleanup that cannot fail - // (e.g., releasing memory) should generally be handled in the deleter. - virtual AdbcStatusCode Release(AdbcError* error) { return ADBC_STATUS_OK; } - - // Get an option that was previously set, providing an optional default value. - virtual const Option& GetOption(const std::string& key, - const Option& default_value = Option()) const { - auto result = options_.find(key); - if (result == options_.end()) { - return default_value; - } else { - return result->second; - } - } - - protected: - // Needed to export errors using Error::ToAdbc() that use 1.1.0 extensions - // (i.e., error details). This will be nullptr before Init() is called. - AdbcDriver* driver() const { return driver_; } - - private: - AdbcDriver* driver_; - std::unordered_map options_; - - // Let the Driver use these to expose C callables wrapping option setters/getters - template - friend class Driver; - - // The AdbcDriver* struct is set right before Init() is called by the Driver - // trampoline. - void set_driver(AdbcDriver* driver) { driver_ = driver; } - - template - AdbcStatusCode CSetOption(const char* key, T value, AdbcError* error) { - Option option(value); - return SetOption(key, option); - } - - AdbcStatusCode CSetOptionBytes(const char* key, const uint8_t* value, size_t length, - AdbcError* error) { - std::vector cppvalue(value, value + length); - Option option(cppvalue); - return SetOption(key, option); - } - - template - AdbcStatusCode CGetOptionStringLike(const char* key, T* value, size_t* length, - AdbcError* error) const { - Option result = GetOption(key); - if (result.type() == Option::TYPE_MISSING) { - InitErrorNotFound(key, error); - return ADBC_STATUS_NOT_FOUND; - } else { - AdbcStatusCode status = result.CGet(value, length); - if (status != ADBC_STATUS_OK) { - InitErrorWrongType(key, error); - } - - return status; - } - } - - template - AdbcStatusCode CGetOptionNumeric(const char* key, T* value, AdbcError* error) const { - Option result = GetOption(key); - if (result.type() == Option::TYPE_MISSING) { - InitErrorNotFound(key, error); - return ADBC_STATUS_NOT_FOUND; - } else { - AdbcStatusCode status = result.CGet(value); - if (status != ADBC_STATUS_OK) { - InitErrorWrongType(key, error); - } - - return status; - } - } - - void InitErrorNotFound(const char* key, AdbcError* error) const { - std::stringstream msg_builder; - msg_builder << "Option not found for key '" << key << "'"; - Error cpperror(msg_builder.str()); - cpperror.AddDetail("adbc.driver_base.option_key", key); - cpperror.ToAdbc(error, driver()); - } - - void InitErrorWrongType(const char* key, AdbcError* error) const { - std::stringstream msg_builder; - msg_builder << "Wrong type requested for option key '" << key << "'"; - Error cpperror(msg_builder.str()); - cpperror.AddDetail("adbc.driver_base.option_key", key); - cpperror.ToAdbc(error, driver()); - } -}; - -// Driver authors can subclass DatabaseObjectBase to track driver-specific -// state pertaining to the AdbcDatbase. The private_data member of an -// AdbcDatabase initialized by the driver will be a pointer to the -// subclass of DatbaseObjectBase. 
-class DatabaseObjectBase : public ObjectBase { - public: - // (there are no database functions other than option getting/setting) -}; - -// Driver authors can subclass ConnectionObjectBase to track driver-specific -// state pertaining to the AdbcConnection. The private_data member of an -// AdbcConnection initialized by the driver will be a pointer to the -// subclass of ConnectionObjectBase. Driver authors can override methods to -// implement the corresponding ConnectionXXX driver methods. -class ConnectionObjectBase : public ObjectBase { - public: - virtual AdbcStatusCode Commit(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } - - virtual AdbcStatusCode GetInfo(const uint32_t* info_codes, size_t info_codes_length, - ArrowArrayStream* out, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode GetObjects(int depth, const char* catalog, const char* db_schema, - const char* table_name, const char** table_type, - const char* column_name, ArrowArrayStream* out, - AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, - const char* table_name, ArrowSchema* schema, - AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode GetTableTypes(ArrowArrayStream* out, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode ReadPartition(const uint8_t* serialized_partition, - size_t serialized_length, ArrowArrayStream* out, - AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode Rollback(AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode Cancel(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } - - virtual AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema, - const char* table_name, char approximate, - ArrowArrayStream* out, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode GetStatisticNames(ArrowArrayStream* out, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } -}; - -// Driver authors can subclass StatementObjectBase to track driver-specific -// state pertaining to the AdbcStatement. The private_data member of an -// AdbcStatement initialized by the driver will be a pointer to the -// subclass of StatementObjectBase. Driver authors can override methods to -// implement the corresponding StatementXXX driver methods. 
-class StatementObjectBase : public ObjectBase { - public: - virtual AdbcStatusCode ExecuteQuery(ArrowArrayStream* stream, int64_t* rows_affected, - AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode ExecuteSchema(ArrowSchema* schema, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode Prepare(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } - - virtual AdbcStatusCode SetSqlQuery(const char* query, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode SetSubstraitPlan(const uint8_t* plan, size_t length, - AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode Bind(ArrowArray* values, ArrowSchema* schema, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode BindStream(ArrowArrayStream* stream, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode Cancel(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } -}; - -// Driver authors can declare a template specialization of the Driver class -// and use it to provide their driver init function. It is possible, but -// rarely useful, to subclass a driver. -template -class Driver { - public: - static AdbcStatusCode Init(int version, void* raw_driver, AdbcError* error) { - if (version != ADBC_VERSION_1_1_0) return ADBC_STATUS_NOT_IMPLEMENTED; - AdbcDriver* driver = (AdbcDriver*)raw_driver; - std::memset(driver, 0, sizeof(AdbcDriver)); - - // Driver lifecycle - driver->private_data = new Driver(); - driver->release = &CDriverRelease; - - // Driver functions - driver->ErrorGetDetailCount = &CErrorGetDetailCount; - driver->ErrorGetDetail = &CErrorGetDetail; - - // Database lifecycle - driver->DatabaseNew = &CNew; - driver->DatabaseInit = &CDatabaseInit; - driver->DatabaseRelease = &CRelease; - - // Database functions - driver->DatabaseSetOption = &CSetOption; - driver->DatabaseSetOptionBytes = &CSetOptionBytes; - driver->DatabaseSetOptionInt = &CSetOptionInt; - driver->DatabaseSetOptionDouble = &CSetOptionDouble; - driver->DatabaseGetOption = &CGetOption; - driver->DatabaseGetOptionBytes = &CGetOptionBytes; - driver->DatabaseGetOptionInt = &CGetOptionInt; - driver->DatabaseGetOptionDouble = &CGetOptionDouble; - - // Connection lifecycle - driver->ConnectionNew = &CNew; - driver->ConnectionInit = &CConnectionInit; - driver->ConnectionRelease = &CRelease; - - // Connection functions - driver->ConnectionSetOption = &CSetOption; - driver->ConnectionSetOptionBytes = &CSetOptionBytes; - driver->ConnectionSetOptionInt = &CSetOptionInt; - driver->ConnectionSetOptionDouble = &CSetOptionDouble; - driver->ConnectionGetOption = &CGetOption; - driver->ConnectionGetOptionBytes = &CGetOptionBytes; - driver->ConnectionGetOptionInt = &CGetOptionInt; - driver->ConnectionGetOptionDouble = &CGetOptionDouble; - driver->ConnectionCommit = &CConnectionCommit; - driver->ConnectionGetInfo = &CConnectionGetInfo; - driver->ConnectionGetObjects = &CConnectionGetObjects; - driver->ConnectionGetTableSchema = &CConnectionGetTableSchema; - driver->ConnectionGetTableTypes = &CConnectionGetTableTypes; - driver->ConnectionReadPartition = &CConnectionReadPartition; - driver->ConnectionRollback = &CConnectionRollback; - driver->ConnectionCancel = &CConnectionCancel; - driver->ConnectionGetStatistics = &CConnectionGetStatistics; - driver->ConnectionGetStatisticNames = &CConnectionGetStatisticNames; - - // Statement lifecycle - driver->StatementNew = 
&CStatementNew; - driver->StatementRelease = &CRelease; - - // Statement functions - driver->StatementSetOption = &CSetOption; - driver->StatementSetOptionBytes = &CSetOptionBytes; - driver->StatementSetOptionInt = &CSetOptionInt; - driver->StatementSetOptionDouble = &CSetOptionDouble; - driver->StatementGetOption = &CGetOption; - driver->StatementGetOptionBytes = &CGetOptionBytes; - driver->StatementGetOptionInt = &CGetOptionInt; - driver->StatementGetOptionDouble = &CGetOptionDouble; - - driver->StatementExecuteQuery = &CStatementExecuteQuery; - driver->StatementExecuteSchema = &CStatementExecuteSchema; - driver->StatementPrepare = &CStatementPrepare; - driver->StatementSetSqlQuery = &CStatementSetSqlQuery; - driver->StatementSetSubstraitPlan = &CStatementSetSubstraitPlan; - driver->StatementBind = &CStatementBind; - driver->StatementBindStream = &CStatementBindStream; - driver->StatementCancel = &CStatementCancel; - - return ADBC_STATUS_OK; - } - - private: - // Driver trampolines - static AdbcStatusCode CDriverRelease(AdbcDriver* driver, AdbcError* error) { - auto driver_private = reinterpret_cast(driver->private_data); - delete driver_private; - driver->private_data = nullptr; - return ADBC_STATUS_OK; - } - - static int CErrorGetDetailCount(const AdbcError* error) { - if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { - return 0; - } - - auto error_obj = reinterpret_cast(error->private_data); - return error_obj->CDetailCount(); - } - - static AdbcErrorDetail CErrorGetDetail(const AdbcError* error, int index) { - auto error_obj = reinterpret_cast(error->private_data); - return error_obj->CDetail(index); - } - - // Templatable trampolines - template - static AdbcStatusCode CNew(T* obj, AdbcError* error) { - auto private_data = new ObjectT(); - obj->private_data = private_data; - return ADBC_STATUS_OK; - } - - template - static AdbcStatusCode CRelease(T* obj, AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - AdbcStatusCode result = private_data->Release(error); - if (result != ADBC_STATUS_OK) { - return result; - } - - delete private_data; - obj->private_data = nullptr; - return ADBC_STATUS_OK; - } - - template - static AdbcStatusCode CSetOption(T* obj, const char* key, const char* value, - AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - return private_data->template CSetOption<>(key, value, error); - } - - template - static AdbcStatusCode CSetOptionBytes(T* obj, const char* key, const uint8_t* value, - size_t length, AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - return private_data->CSetOptionBytes(key, value, length, error); - } - - template - static AdbcStatusCode CSetOptionInt(T* obj, const char* key, int64_t value, - AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - return private_data->template CSetOption<>(key, value, error); - } - - template - static AdbcStatusCode CSetOptionDouble(T* obj, const char* key, double value, - AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - return private_data->template CSetOption<>(key, value, error); - } - - template - static AdbcStatusCode CGetOption(T* obj, const char* key, char* value, size_t* length, - AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - return private_data->template CGetOptionStringLike<>(key, value, length, error); - } - - template - static AdbcStatusCode CGetOptionBytes(T* obj, const char* key, uint8_t* value, - size_t* 
length, AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - return private_data->template CGetOptionStringLike<>(key, value, length, error); - } - - template - static AdbcStatusCode CGetOptionInt(T* obj, const char* key, int64_t* value, - AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - return private_data->template CGetOptionNumeric<>(key, value, error); - } - - template - static AdbcStatusCode CGetOptionDouble(T* obj, const char* key, double* value, - AdbcError* error) { - auto private_data = reinterpret_cast(obj->private_data); - return private_data->template CGetOptionNumeric<>(key, value, error); - } - - // Database trampolines - static AdbcStatusCode CDatabaseInit(AdbcDatabase* database, AdbcError* error) { - auto private_data = reinterpret_cast(database->private_data); - private_data->set_driver(database->private_driver); - return private_data->Init(database->private_driver->private_data, error); - } - - // Connection trampolines - static AdbcStatusCode CConnectionInit(AdbcConnection* connection, - AdbcDatabase* database, AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - private_data->set_driver(connection->private_driver); - return private_data->Init(database->private_data, error); - } - - static AdbcStatusCode CConnectionCancel(AdbcConnection* connection, AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->Cancel(error); - } - - static AdbcStatusCode CConnectionGetInfo(AdbcConnection* connection, - const uint32_t* info_codes, - size_t info_codes_length, - ArrowArrayStream* out, AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->GetInfo(info_codes, info_codes_length, out, error); - } - - static AdbcStatusCode CConnectionGetObjects(AdbcConnection* connection, int depth, - const char* catalog, const char* db_schema, - const char* table_name, - const char** table_type, - const char* column_name, - ArrowArrayStream* out, AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->GetObjects(depth, catalog, db_schema, table_name, table_type, - column_name, out, error); - } - - static AdbcStatusCode CConnectionGetStatistics( - AdbcConnection* connection, const char* catalog, const char* db_schema, - const char* table_name, char approximate, ArrowArrayStream* out, AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->GetStatistics(catalog, db_schema, table_name, approximate, out, - error); - } - - static AdbcStatusCode CConnectionGetStatisticNames(AdbcConnection* connection, - ArrowArrayStream* out, - AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->GetStatisticNames(out, error); - } - - static AdbcStatusCode CConnectionGetTableSchema(AdbcConnection* connection, - const char* catalog, - const char* db_schema, - const char* table_name, - ArrowSchema* schema, AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->GetTableSchema(catalog, db_schema, table_name, schema, error); - } - - static AdbcStatusCode CConnectionGetTableTypes(AdbcConnection* connection, - ArrowArrayStream* out, - AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->GetTableTypes(out, error); - } - - static AdbcStatusCode 
CConnectionReadPartition(AdbcConnection* connection, - const uint8_t* serialized_partition, - size_t serialized_length, - ArrowArrayStream* out, - AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->ReadPartition(serialized_partition, serialized_length, out, - error); - } - - static AdbcStatusCode CConnectionCommit(AdbcConnection* connection, AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->Commit(error); - } - - static AdbcStatusCode CConnectionRollback(AdbcConnection* connection, - AdbcError* error) { - auto private_data = reinterpret_cast(connection->private_data); - return private_data->Rollback(error); - } - - // Statement trampolines - static AdbcStatusCode CStatementNew(AdbcConnection* connection, - AdbcStatement* statement, AdbcError* error) { - auto private_data = new StatementT(); - private_data->set_driver(connection->private_driver); - AdbcStatusCode status = private_data->Init(connection->private_data, error); - if (status != ADBC_STATUS_OK) { - delete private_data; - } - - statement->private_data = private_data; - return ADBC_STATUS_OK; - } - - static AdbcStatusCode CStatementExecuteQuery(AdbcStatement* statement, - ArrowArrayStream* stream, - int64_t* rows_affected, AdbcError* error) { - auto private_data = reinterpret_cast(statement->private_data); - return private_data->ExecuteQuery(stream, rows_affected, error); - } - - static AdbcStatusCode CStatementExecuteSchema(AdbcStatement* statement, - ArrowSchema* schema, AdbcError* error) { - auto private_data = reinterpret_cast(statement->private_data); - return private_data->ExecuteSchema(schema, error); - } - - static AdbcStatusCode CStatementPrepare(AdbcStatement* statement, AdbcError* error) { - auto private_data = reinterpret_cast(statement->private_data); - return private_data->Prepare(error); - } - - static AdbcStatusCode CStatementSetSqlQuery(AdbcStatement* statement, const char* query, - AdbcError* error) { - auto private_data = reinterpret_cast(statement->private_data); - return private_data->SetSqlQuery(query, error); - } - - static AdbcStatusCode CStatementSetSubstraitPlan(AdbcStatement* statement, - const uint8_t* plan, size_t length, - AdbcError* error) { - auto private_data = reinterpret_cast(statement->private_data); - return private_data->SetSubstraitPlan(plan, length, error); - } - - static AdbcStatusCode CStatementBind(AdbcStatement* statement, ArrowArray* values, - ArrowSchema* schema, AdbcError* error) { - auto private_data = reinterpret_cast(statement->private_data); - return private_data->Bind(values, schema, error); - } - - static AdbcStatusCode CStatementBindStream(AdbcStatement* statement, - ArrowArrayStream* stream, AdbcError* error) { - auto private_data = reinterpret_cast(statement->private_data); - return private_data->BindStream(stream, error); - } - - static AdbcStatusCode CStatementCancel(AdbcStatement* statement, AdbcError* error) { - auto private_data = reinterpret_cast(statement->private_data); - return private_data->Cancel(error); - } -}; - -} // namespace common - -} // namespace adbc diff --git a/c/driver/common/driver_test.cc b/c/driver/common/driver_test.cc deleted file mode 100644 index eaabc604a2..0000000000 --- a/c/driver/common/driver_test.cc +++ /dev/null @@ -1,271 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
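Before its removal, the header above assembled a complete `AdbcDriver` from three small classes plus the templated `adbc::common::Driver`, which is exactly how the deleted `driver_test.cc` below builds its void driver. The following is a condensed sketch of that now-removed pattern, reconstructed from the header's own example comment; it no longer compiles once this patch lands, since `driver_base.h` is gone.

```cpp
// Sketch of the removed adbc::common driver framework. Class names mirror the
// example in the deleted header; nothing here survives this change.
#include <string>

#include <adbc.h>

#include "driver_base.h"  // deleted in this patch

class MyDatabase : public adbc::common::DatabaseObjectBase {};
class MyConnection : public adbc::common::ConnectionObjectBase {};

class MyStatement : public adbc::common::StatementObjectBase {
 public:
  // Override only what the driver supports; the base class reports
  // ADBC_STATUS_NOT_IMPLEMENTED for everything else.
  AdbcStatusCode SetSqlQuery(const char* query, AdbcError* error) override {
    query_ = query;
    return ADBC_STATUS_OK;
  }

 private:
  std::string query_;
};

// The templated Driver fills the AdbcDriver struct with trampolines that
// forward each C entry point to the classes above.
AdbcStatusCode MyDriverInitFunc(int version, void* raw_driver, AdbcError* error) {
  return adbc::common::Driver<MyDatabase, MyConnection, MyStatement>::Init(
      version, raw_driver, error);
}
```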
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include - -#include -#include "driver_base.h" - -// Self-contained version of the Handle -static inline void clean_up(AdbcDriver* ptr) { ptr->release(ptr, nullptr); } - -static inline void clean_up(AdbcDatabase* ptr) { - ptr->private_driver->DatabaseRelease(ptr, nullptr); -} - -static inline void clean_up(AdbcConnection* ptr) { - ptr->private_driver->ConnectionRelease(ptr, nullptr); -} - -static inline void clean_up(AdbcStatement* ptr) { - ptr->private_driver->StatementRelease(ptr, nullptr); -} - -static inline void clean_up(AdbcError* ptr) { - if (ptr->release != nullptr) { - ptr->release(ptr); - } -} - -template -class Handle { - public: - explicit Handle(T* value) : value_(value) {} - - ~Handle() { clean_up(value_); } - - private: - T* value_; -}; - -class VoidDatabase : public adbc::common::DatabaseObjectBase {}; - -class VoidConnection : public adbc::common::ConnectionObjectBase {}; - -class VoidStatement : public adbc::common::StatementObjectBase {}; - -using VoidDriver = adbc::common::Driver; - -AdbcStatusCode VoidDriverInitFunc(int version, void* raw_driver, AdbcError* error) { - return VoidDriver::Init(version, raw_driver, error); -} - -TEST(TestDriverBase, TestVoidDriverOptions) { - // Test the get/set option implementation in the base driver - struct AdbcDriver driver; - memset(&driver, 0, sizeof(driver)); - ASSERT_EQ(VoidDriverInitFunc(ADBC_VERSION_1_1_0, &driver, nullptr), ADBC_STATUS_OK); - Handle driver_handle(&driver); - - struct AdbcDatabase database; - memset(&database, 0, sizeof(database)); - ASSERT_EQ(driver.DatabaseNew(&database, nullptr), ADBC_STATUS_OK); - database.private_driver = &driver; - Handle database_handle(&database); - ASSERT_EQ(driver.DatabaseInit(&database, nullptr), ADBC_STATUS_OK); - - std::vector opt_string; - std::vector opt_bytes; - size_t opt_size = 0; - int64_t opt_int = 0; - double opt_double = 0; - - // Check return codes without an error pointer for non-existent keys - ASSERT_EQ(driver.DatabaseGetOption(&database, "key_that_does_not_exist", nullptr, - &opt_size, nullptr), - ADBC_STATUS_NOT_FOUND); - ASSERT_EQ(driver.DatabaseGetOptionBytes(&database, "key_that_does_not_exist", nullptr, - &opt_size, nullptr), - ADBC_STATUS_NOT_FOUND); - ASSERT_EQ(driver.DatabaseGetOptionInt(&database, "key_that_does_not_exist", &opt_int, - nullptr), - ADBC_STATUS_NOT_FOUND); - ASSERT_EQ(driver.DatabaseGetOptionDouble(&database, "key_that_does_not_exist", - &opt_double, nullptr), - ADBC_STATUS_NOT_FOUND); - - // Check set/get for string - ASSERT_EQ(driver.DatabaseSetOption(&database, "key_string", "value_string", nullptr), - ADBC_STATUS_OK); - opt_size = 0; - ASSERT_EQ( - driver.DatabaseGetOption(&database, "key_string", nullptr, &opt_size, nullptr), - ADBC_STATUS_OK); - ASSERT_EQ(opt_size, strlen("value_string") + 1); - opt_string.resize(opt_size); - 
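The assertions just above and below exercise ADBC's two-call protocol for string options: the first `DatabaseGetOption` call with a null or too-small buffer reports the required length, the caller resizes, and the second call copies the value. Outside a test, the same pattern looks roughly like the sketch below, assuming the standard `AdbcDatabaseGetOption` entry point from ADBC 1.1.0 and an already-initialized database.

```cpp
// Sketch: read a string option using ADBC 1.1.0 length negotiation. The first
// call reports the required size (including the NUL terminator); the second
// call fills the buffer.
#include <adbc.h>

#include <string>
#include <vector>

std::string GetStringOption(struct AdbcDatabase* database, const char* key,
                            struct AdbcError* error) {
  size_t length = 0;
  if (AdbcDatabaseGetOption(database, key, nullptr, &length, error) != ADBC_STATUS_OK) {
    return "";
  }
  std::vector<char> buffer(length);
  if (AdbcDatabaseGetOption(database, key, buffer.data(), &length, error) !=
      ADBC_STATUS_OK) {
    return "";
  }
  return std::string(buffer.data());
}
```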
ASSERT_EQ(driver.DatabaseGetOption(&database, "key_string", opt_string.data(), - &opt_size, nullptr), - ADBC_STATUS_OK); - - // Check set/get for bytes - const uint8_t test_bytes[] = {0x01, 0x02, 0x03}; - ASSERT_EQ(driver.DatabaseSetOptionBytes(&database, "key_bytes", test_bytes, - sizeof(test_bytes), nullptr), - ADBC_STATUS_OK); - opt_size = 0; - ASSERT_EQ( - driver.DatabaseGetOptionBytes(&database, "key_bytes", nullptr, &opt_size, nullptr), - ADBC_STATUS_OK); - ASSERT_EQ(opt_size, sizeof(test_bytes)); - opt_bytes.resize(opt_size); - ASSERT_EQ(driver.DatabaseGetOptionBytes(&database, "key_bytes", opt_bytes.data(), - &opt_size, nullptr), - ADBC_STATUS_OK); - - // Check set/get for int - ASSERT_EQ(driver.DatabaseSetOptionInt(&database, "key_int", 1234, nullptr), - ADBC_STATUS_OK); - ASSERT_EQ(driver.DatabaseGetOptionInt(&database, "key_int", &opt_int, nullptr), - ADBC_STATUS_OK); - ASSERT_EQ(opt_int, 1234); - - // Check set/get for double - ASSERT_EQ(driver.DatabaseSetOptionDouble(&database, "key_double", 1234.5, nullptr), - ADBC_STATUS_OK); - ASSERT_EQ(driver.DatabaseGetOptionDouble(&database, "key_double", &opt_double, nullptr), - ADBC_STATUS_OK); - ASSERT_EQ(opt_double, 1234.5); - - // Check error code for getting a key of an incorrect type - opt_size = 0; - ASSERT_EQ(driver.DatabaseGetOption(&database, "key_bytes", nullptr, &opt_size, nullptr), - ADBC_STATUS_NOT_FOUND); - ASSERT_EQ( - driver.DatabaseGetOptionBytes(&database, "key_string", nullptr, &opt_size, nullptr), - ADBC_STATUS_NOT_FOUND); - ASSERT_EQ(driver.DatabaseGetOptionInt(&database, "key_bytes", &opt_int, nullptr), - ADBC_STATUS_NOT_FOUND); - ASSERT_EQ(driver.DatabaseGetOptionDouble(&database, "key_bytes", &opt_double, nullptr), - ADBC_STATUS_NOT_FOUND); -} - -TEST(TestDriverBase, TestVoidDriverError) { - // Test the extended error detail implementation in the base driver - struct AdbcDriver driver; - memset(&driver, 0, sizeof(driver)); - ASSERT_EQ(VoidDriverInitFunc(ADBC_VERSION_1_1_0, &driver, nullptr), ADBC_STATUS_OK); - Handle driver_handle(&driver); - - struct AdbcDatabase database; - memset(&database, 0, sizeof(database)); - ASSERT_EQ(driver.DatabaseNew(&database, nullptr), ADBC_STATUS_OK); - database.private_driver = &driver; - Handle database_handle(&database); - ASSERT_EQ(driver.DatabaseInit(&database, nullptr), ADBC_STATUS_OK); - - struct AdbcError error; - memset(&error, 0, sizeof(error)); - Handle error_handle(&error); - size_t opt_size = 0; - - // With zero-initialized error, should populate message but not details - ASSERT_EQ(driver.DatabaseGetOption(&database, "key_does_not_exist", nullptr, &opt_size, - &error), - ADBC_STATUS_NOT_FOUND); - EXPECT_EQ(error.vendor_code, 0); - EXPECT_STREQ(error.message, "Option not found for key 'key_does_not_exist'"); - EXPECT_EQ(error.private_data, nullptr); - EXPECT_EQ(error.private_driver, nullptr); - - // Release callback implementation should reset callback - error.release(&error); - ASSERT_EQ(error.release, nullptr); - - // With the vendor code pre-set, should populate a version with details - memset(&error, 0, sizeof(error)); - error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; - - ASSERT_EQ(driver.DatabaseGetOption(&database, "key_does_not_exist", nullptr, &opt_size, - &error), - ADBC_STATUS_NOT_FOUND); - EXPECT_NE(error.private_data, nullptr); - EXPECT_EQ(error.private_driver, &driver); - - ASSERT_EQ(error.private_driver->ErrorGetDetailCount(&error), 1); - - struct AdbcErrorDetail detail = error.private_driver->ErrorGetDetail(&error, 0); - ASSERT_STREQ(detail.key, 
"adbc.driver_base.option_key"); - ASSERT_EQ(detail.value_length, strlen("key_does_not_exist") + 1); - ASSERT_STREQ(reinterpret_cast(detail.value), "key_does_not_exist"); -} - -TEST(TestDriverBase, TestVoidDriverMethods) { - struct AdbcDriver driver; - memset(&driver, 0, sizeof(driver)); - ASSERT_EQ(VoidDriverInitFunc(ADBC_VERSION_1_1_0, &driver, nullptr), ADBC_STATUS_OK); - Handle driver_handle(&driver); - - // Database methods are only option related - struct AdbcDatabase database; - memset(&database, 0, sizeof(database)); - ASSERT_EQ(driver.DatabaseNew(&database, nullptr), ADBC_STATUS_OK); - database.private_driver = &driver; - Handle database_handle(&database); - ASSERT_EQ(driver.DatabaseInit(&database, nullptr), ADBC_STATUS_OK); - - // Test connection methods - struct AdbcConnection connection; - memset(&connection, 0, sizeof(connection)); - ASSERT_EQ(driver.ConnectionNew(&connection, nullptr), ADBC_STATUS_OK); - connection.private_driver = &driver; - Handle connection_handle(&connection); - ASSERT_EQ(driver.ConnectionInit(&connection, &database, nullptr), ADBC_STATUS_OK); - - EXPECT_EQ(driver.ConnectionCommit(&connection, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionGetInfo(&connection, nullptr, 0, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionGetObjects(&connection, 0, nullptr, nullptr, 0, nullptr, - nullptr, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionGetTableSchema(&connection, nullptr, nullptr, nullptr, - nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionGetTableTypes(&connection, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionReadPartition(&connection, nullptr, 0, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionRollback(&connection, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionCancel(&connection, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionGetStatistics(&connection, nullptr, nullptr, nullptr, 0, - nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.ConnectionGetStatisticNames(&connection, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - - // Test statement methods - struct AdbcStatement statement; - memset(&statement, 0, sizeof(statement)); - ASSERT_EQ(driver.StatementNew(&connection, &statement, nullptr), ADBC_STATUS_OK); - statement.private_driver = &driver; - Handle statement_handle(&statement); - - EXPECT_EQ(driver.StatementExecuteQuery(&statement, nullptr, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.StatementExecuteSchema(&statement, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.StatementPrepare(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.StatementSetSqlQuery(&statement, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.StatementSetSubstraitPlan(&statement, nullptr, 0, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.StatementBind(&statement, nullptr, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.StatementBindStream(&statement, nullptr, nullptr), - ADBC_STATUS_NOT_IMPLEMENTED); - EXPECT_EQ(driver.StatementCancel(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); -} diff --git a/c/driver/common/meson.build b/c/driver/common/meson.build new file mode 100644 index 0000000000..b1423f0e58 --- /dev/null +++ b/c/driver/common/meson.build @@ -0,0 +1,35 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +adbc_common_lib = library( + 'adbc_driver_common', + sources: ['utils.c'], + include_directories: [include_dir], + dependencies: [nanoarrow_dep], + install: true, +) + +if get_option('tests') + exc = executable( + 'adbc-driver-common-test', + 'utils_test.cc', + include_directories: [include_dir], + link_with: [adbc_common_lib], + dependencies: [nanoarrow_dep, gtest_main_dep, gmock_dep], + ) + test('adbc-driver-common', exc) +endif diff --git a/c/driver/common/utils.c b/c/driver/common/utils.c index 795d79f973..00ebd51939 100644 --- a/c/driver/common/utils.c +++ b/c/driver/common/utils.c @@ -23,7 +23,7 @@ #include #include -#include +#include static size_t kErrorBufferSize = 1024; @@ -165,7 +165,7 @@ void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* return; } - size_t* new_lengths = calloc(new_capacity, sizeof(size_t*)); + size_t* new_lengths = calloc(new_capacity, sizeof(size_t)); if (!new_lengths) { free(new_keys); free(new_values); @@ -193,8 +193,10 @@ void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* details->capacity = new_capacity; } - char* key_data = strdup(key); + char* key_data = malloc(strlen(key) + 1); if (!key_data) return; + memcpy(key_data, key, strlen(key) + 1); + uint8_t* value_data = malloc(detail_length); if (!value_data) { free(key_data); @@ -233,70 +235,8 @@ struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int i }; } -struct SingleBatchArrayStream { - struct ArrowSchema schema; - struct ArrowArray batch; -}; -static const char* SingleBatchArrayStreamGetLastError(struct ArrowArrayStream* stream) { - (void)stream; - return NULL; -} -static int SingleBatchArrayStreamGetNext(struct ArrowArrayStream* stream, - struct ArrowArray* batch) { - if (!stream || !stream->private_data) return EINVAL; - struct SingleBatchArrayStream* impl = - (struct SingleBatchArrayStream*)stream->private_data; - - memcpy(batch, &impl->batch, sizeof(*batch)); - memset(&impl->batch, 0, sizeof(*batch)); - return 0; -} -static int SingleBatchArrayStreamGetSchema(struct ArrowArrayStream* stream, - struct ArrowSchema* schema) { - if (!stream || !stream->private_data) return EINVAL; - struct SingleBatchArrayStream* impl = - (struct SingleBatchArrayStream*)stream->private_data; - - return ArrowSchemaDeepCopy(&impl->schema, schema); -} -static void SingleBatchArrayStreamRelease(struct ArrowArrayStream* stream) { - if (!stream || !stream->private_data) return; - struct SingleBatchArrayStream* impl = - (struct SingleBatchArrayStream*)stream->private_data; - impl->schema.release(&impl->schema); - if (impl->batch.release) impl->batch.release(&impl->batch); - free(impl); - - memset(stream, 0, sizeof(*stream)); -} - -AdbcStatusCode BatchToArrayStream(struct ArrowArray* values, 
struct ArrowSchema* schema, - struct ArrowArrayStream* stream, - struct AdbcError* error) { - if (!values->release) { - SetError(error, "ArrowArray is not initialized"); - return ADBC_STATUS_INTERNAL; - } else if (!schema->release) { - SetError(error, "ArrowSchema is not initialized"); - return ADBC_STATUS_INTERNAL; - } else if (stream->release) { - SetError(error, "ArrowArrayStream is already initialized"); - return ADBC_STATUS_INTERNAL; - } - - struct SingleBatchArrayStream* impl = - (struct SingleBatchArrayStream*)malloc(sizeof(*impl)); - memcpy(&impl->schema, schema, sizeof(*schema)); - memcpy(&impl->batch, values, sizeof(*values)); - memset(schema, 0, sizeof(*schema)); - memset(values, 0, sizeof(*values)); - stream->private_data = impl; - stream->get_last_error = SingleBatchArrayStreamGetLastError; - stream->get_next = SingleBatchArrayStreamGetNext; - stream->get_schema = SingleBatchArrayStreamGetSchema; - stream->release = SingleBatchArrayStreamRelease; - - return ADBC_STATUS_OK; +bool IsCommonError(const struct AdbcError* error) { + return error->release == ReleaseErrorWithDetails || error->release == ReleaseError; } int StringBuilderInit(struct StringBuilder* builder, size_t initial_size) { diff --git a/c/driver/common/utils.h b/c/driver/common/utils.h index cab5ddbe28..d204821b2b 100644 --- a/c/driver/common/utils.h +++ b/c/driver/common/utils.h @@ -22,7 +22,7 @@ #include #include -#include +#include #include "nanoarrow/nanoarrow.h" #ifdef __cplusplus @@ -53,6 +53,7 @@ void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* int CommonErrorGetDetailCount(const struct AdbcError* error); struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int index); +bool IsCommonError(const struct AdbcError* error); struct StringBuilder { char* buffer; @@ -68,11 +69,6 @@ void StringBuilderReset(struct StringBuilder* builder); #undef ADBC_CHECK_PRINTF_ATTRIBUTE -/// Wrap a single batch as a stream. -AdbcStatusCode BatchToArrayStream(struct ArrowArray* values, struct ArrowSchema* schema, - struct ArrowArrayStream* stream, - struct AdbcError* error); - /// Check an NanoArrow status code. #define CHECK_NA(CODE, EXPR, ERROR) \ do { \ diff --git a/c/driver/flightsql/CMakeLists.txt b/c/driver/flightsql/CMakeLists.txt index a67df3a4d6..1101b82430 100644 --- a/c/driver/flightsql/CMakeLists.txt +++ b/c/driver/flightsql/CMakeLists.txt @@ -35,7 +35,8 @@ add_go_lib("${REPOSITORY_ROOT}/go/adbc/pkg/flightsql/" foreach(LIB_TARGET ${ADBC_LIBRARIES}) target_include_directories(${LIB_TARGET} SYSTEM - INTERFACE ${REPOSITORY_ROOT} ${REPOSITORY_ROOT}/c/ + INTERFACE ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) endforeach() @@ -62,7 +63,7 @@ if(ADBC_BUILD_TESTS) ${TEST_LINK_LIBS}) target_compile_features(adbc-driver-flightsql-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-flightsql-test SYSTEM - PRIVATE ${REPOSITORY_ROOT} ${REPOSITORY_ROOT}/c/ + PRIVATE ${REPOSITORY_ROOT}/c/ ${REPOSITORY_ROOT}/c/include/ ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) adbc_configure_target(adbc-driver-flightsql-test) diff --git a/c/driver/flightsql/dremio_flightsql_test.cc b/c/driver/flightsql/dremio_flightsql_test.cc index acc0682790..f18344017b 100644 --- a/c/driver/flightsql/dremio_flightsql_test.cc +++ b/c/driver/flightsql/dremio_flightsql_test.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-#include +#include #include #include #include @@ -30,13 +30,18 @@ class DremioFlightSqlQuirks : public adbc_validation::DriverQuirks { public: AdbcStatusCode SetupDatabase(struct AdbcDatabase* database, struct AdbcError* error) const override { - const char* uri = std::getenv("ADBC_DREMIO_FLIGHTSQL_URI"); - const char* user = std::getenv("ADBC_DREMIO_FLIGHTSQL_USER"); - const char* pass = std::getenv("ADBC_DREMIO_FLIGHTSQL_PASS"); - EXPECT_THAT(AdbcDatabaseSetOption(database, "uri", uri, error), IsOkStatus(error)); - EXPECT_THAT(AdbcDatabaseSetOption(database, "username", user, error), + const char* uri_raw = std::getenv("ADBC_DREMIO_FLIGHTSQL_URI"); + const char* user_raw = std::getenv("ADBC_DREMIO_FLIGHTSQL_USER"); + const char* pass_raw = std::getenv("ADBC_DREMIO_FLIGHTSQL_PASS"); + if (!uri_raw || !user_raw || !pass_raw) { + SetError(error, "Missing required environment variables"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + EXPECT_THAT(AdbcDatabaseSetOption(database, "uri", uri_raw, error), IsOkStatus(error)); - EXPECT_THAT(AdbcDatabaseSetOption(database, "password", pass, error), + EXPECT_THAT(AdbcDatabaseSetOption(database, "username", user_raw, error), + IsOkStatus(error)); + EXPECT_THAT(AdbcDatabaseSetOption(database, "password", pass_raw, error), IsOkStatus(error)); return ADBC_STATUS_OK; } diff --git a/c/driver/flightsql/meson.build b/c/driver/flightsql/meson.build new file mode 100644 index 0000000000..cac24d5cf8 --- /dev/null +++ b/c/driver/flightsql/meson.build @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
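The Dremio test change above stops passing possibly-NULL std::getenv results straight into AdbcDatabaseSetOption and instead fails fast with ADBC_STATUS_INVALID_ARGUMENT. A small, hypothetical helper expressing the same guard (the name RequiredEnv is illustrative and not part of this patch) could look like:

    #include <cstdlib>
    #include <optional>
    #include <string>

    // Returns the value of a required environment variable, or std::nullopt so
    // the caller can report ADBC_STATUS_INVALID_ARGUMENT rather than passing a
    // NULL pointer on to AdbcDatabaseSetOption.
    std::optional<std::string> RequiredEnv(const char* name) {
      const char* value = std::getenv(name);
      if (value == nullptr) {
        return std::nullopt;
      }
      return std::string(value);
    }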
+ + +golang = find_program('go') + +if build_machine.system() == 'windows' + prefix = '' + suffix = '.lib' +elif build_machine.system() == 'darwin' + prefix = 'lib' + suffix = '.dylib' +else + prefix = 'lib' + suffix = '.so' +endif + +adbc_driver_flightsql_name = prefix + 'adbc_driver_flightsql' + suffix +adbc_driver_flightsql_lib = custom_target( + 'adbc_driver_flightsql', + output: adbc_driver_flightsql_name, + command : [ + golang, + 'build', + '-C', + meson.project_source_root() + '/../go/adbc/pkg/flightsql', + '-tags=driverlib', + '-buildmode=c-shared', + '-o', + meson.current_build_dir() + '/' + adbc_driver_flightsql_name, + ], + install : true, + install_dir : '.', +) + +pkg.generate( + name: 'Apache Arrow Database Connectivity (ADBC) Flight SQL driver', + description: 'The ADBC Flight SQL driver provides an ADBC driver for Flight SQL.', + url: 'https://github.com/apache/arrow-adbc', + libraries: [adbc_driver_flightsql_lib], + filebase: 'adbc-driver-flightsql', +) + +if get_option('tests') + exc = executable( + 'adbc-driver-flightsql-test', + 'dremio_flightsql_test.cc', + 'sqlite_flightsql_test.cc', + include_directories: [include_dir, c_dir, driver_dir], + link_with: [ + adbc_common_lib, + adbc_driver_flightsql_lib + ], + dependencies: [adbc_validation_dep], + ) + test('adbc-driver-flightsql', exc) +endif diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc index f08eeb884b..4797d58e77 100644 --- a/c/driver/flightsql/sqlite_flightsql_test.cc +++ b/c/driver/flightsql/sqlite_flightsql_test.cc @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include @@ -121,6 +121,7 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks { bool supports_get_objects() const override { return true; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return true; } + std::string catalog() const override { return "main"; } }; class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::DatabaseTest { diff --git a/c/driver/framework/CMakeLists.txt b/c/driver/framework/CMakeLists.txt index d67e5cec03..f5c642b532 100644 --- a/c/driver/framework/CMakeLists.txt +++ b/c/driver/framework/CMakeLists.txt @@ -17,28 +17,29 @@ include(FetchContent) -add_library(adbc_driver_framework STATIC base_driver.cc catalog.cc objects.cc) +add_library(adbc_driver_framework STATIC objects.cc utility.cc) adbc_configure_target(adbc_driver_framework) set_target_properties(adbc_driver_framework PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(adbc_driver_framework - PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/" + PRIVATE "${REPOSITORY_ROOT}/c/" "${REPOSITORY_ROOT}/c/include" "${REPOSITORY_ROOT}/c/vendor") target_link_libraries(adbc_driver_framework PUBLIC adbc_driver_common fmt::fmt) -# if(ADBC_BUILD_TESTS) -# add_test_case(driver_framework_test -# PREFIX -# adbc -# EXTRA_LABELS -# driver-framework -# SOURCES -# utils_test.cc -# driver_test.cc -# EXTRA_LINK_LIBS -# adbc_driver_framework -# nanoarrow) -# target_compile_features(adbc-driver-framework-test PRIVATE cxx_std_17) -# target_include_directories(adbc-driver-framework-test -# PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/vendor") -# adbc_configure_target(adbc-driver-framework-test) -# endif() +if(ADBC_BUILD_TESTS) + add_test_case(driver_framework_test + PREFIX + adbc + EXTRA_LABELS + driver-framework + SOURCES + base_driver_test.cc + EXTRA_LINK_LIBS + 
adbc_driver_framework + nanoarrow) + target_compile_features(adbc-driver-framework-test PRIVATE cxx_std_17) + target_include_directories(adbc-driver-framework-test + PRIVATE "${REPOSITORY_ROOT}/c/" + "${REPOSITORY_ROOT}/c/include" + "${REPOSITORY_ROOT}/c/vendor") + adbc_configure_target(adbc-driver-framework-test) +endif() diff --git a/c/driver/framework/base_driver.cc b/c/driver/framework/base_driver.cc deleted file mode 100644 index cebaae6288..0000000000 --- a/c/driver/framework/base_driver.cc +++ /dev/null @@ -1,160 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "driver/framework/base_driver.h" - -namespace adbc::driver { -Result Option::AsBool() const { - return std::visit( - [&](auto&& value) -> Result { - using T = std::decay_t; - if constexpr (std::is_same_v) { - if (value == ADBC_OPTION_VALUE_ENABLED) { - return true; - } else if (value == ADBC_OPTION_VALUE_DISABLED) { - return false; - } - } - return status::InvalidArgument("Invalid boolean value {}", *this); - }, - value_); -} - -Result Option::AsInt() const { - return std::visit( - [&](auto&& value) -> Result { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return value; - } else if constexpr (std::is_same_v) { - int64_t parsed = 0; - auto begin = value.data(); - auto end = value.data() + value.size(); - auto result = std::from_chars(begin, end, parsed); - if (result.ec != std::errc()) { - return status::InvalidArgument("Invalid integer value '{}': not an integer", - value); - } else if (result.ptr != end) { - return status::InvalidArgument("Invalid integer value '{}': trailing data", - value); - } - return parsed; - } - return status::InvalidArgument("Invalid integer value {}", *this); - }, - value_); -} - -Result Option::AsString() const { - return std::visit( - [&](auto&& value) -> Result { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return value; - } - return status::InvalidArgument("Invalid string value {}", *this); - }, - value_); -} - -AdbcStatusCode Option::CGet(char* out, size_t* length, AdbcError* error) const { - if (!out || !length) { - return status::InvalidArgument("Must provide both out and length to GetOption") - .ToAdbc(error); - } - return std::visit( - [&](auto&& value) -> AdbcStatusCode { - using T = std::decay_t; - if constexpr (std::is_same_v) { - size_t value_size_with_terminator = value.size() + 1; - if (*length >= value_size_with_terminator) { - std::memcpy(out, value.data(), value.size()); - out[value.size()] = 0; - } - *length = value_size_with_terminator; - return ADBC_STATUS_OK; - } else if constexpr (std::is_same_v) { - return status::NotFound("Unknown option").ToAdbc(error); - } else { - return status::NotFound("Option value is not a string").ToAdbc(error); - } - }, - value_); -} - 
-AdbcStatusCode Option::CGet(uint8_t* out, size_t* length, AdbcError* error) const { - if (!out || !length) { - return status::InvalidArgument("Must provide both out and length to GetOption") - .ToAdbc(error); - } - return std::visit( - [&](auto&& value) -> AdbcStatusCode { - using T = std::decay_t; - if constexpr (std::is_same_v || - std::is_same_v>) { - if (*length >= value.size()) { - std::memcpy(out, value.data(), value.size()); - } - *length = value.size(); - return ADBC_STATUS_OK; - } else if constexpr (std::is_same_v) { - return status::NotFound("Unknown option").ToAdbc(error); - } else { - return status::NotFound("Option value is not a bytestring").ToAdbc(error); - } - }, - value_); -} - -AdbcStatusCode Option::CGet(int64_t* out, AdbcError* error) const { - if (!out) { - return status::InvalidArgument("Must provide out to GetOption").ToAdbc(error); - } - return std::visit( - [&](auto&& value) -> AdbcStatusCode { - using T = std::decay_t; - if constexpr (std::is_same_v) { - *out = value; - return ADBC_STATUS_OK; - } else if constexpr (std::is_same_v) { - return status::NotFound("Unknown option").ToAdbc(error); - } else { - return status::NotFound("Option value is not an integer").ToAdbc(error); - } - }, - value_); -} - -AdbcStatusCode Option::CGet(double* out, AdbcError* error) const { - if (!out) { - return status::InvalidArgument("Must provide out to GetOption").ToAdbc(error); - } - return std::visit( - [&](auto&& value) -> AdbcStatusCode { - using T = std::decay_t; - if constexpr (std::is_same_v || std::is_same_v) { - *out = static_cast(value); - return ADBC_STATUS_OK; - } else if constexpr (std::is_same_v) { - return status::NotFound("Unknown option").ToAdbc(error); - } else { - return status::NotFound("Option value is not a double").ToAdbc(error); - } - }, - value_); -} -} // namespace adbc::driver diff --git a/c/driver/framework/base_driver.h b/c/driver/framework/base_driver.h index fcc385ac62..f379121b66 100644 --- a/c/driver/framework/base_driver.h +++ b/c/driver/framework/base_driver.h @@ -28,11 +28,8 @@ #include #include -#include -#include -#include +#include -#include "driver/common/utils.h" #include "driver/framework/status.h" /// \file base.h ADBC Driver Framework @@ -83,23 +80,173 @@ class Option { bool has_value() const { return !std::holds_alternative(value_); } /// \brief Try to parse a string value as a boolean. - Result AsBool() const; + Result AsBool() const { + return std::visit( + [&](auto&& value) -> Result { + using T = std::decay_t; + if constexpr (std::is_same_v) { + if (value == ADBC_OPTION_VALUE_ENABLED) { + return true; + } else if (value == ADBC_OPTION_VALUE_DISABLED) { + return false; + } + } + return status::InvalidArgument("Invalid boolean value ", this->Format()); + }, + value_); + } /// \brief Try to parse a string or integer value as an integer. 
- Result AsInt() const; + Result AsInt() const { + return std::visit( + [&](auto&& value) -> Result { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return value; + } else if constexpr (std::is_same_v) { + int64_t parsed = 0; + auto begin = value.data(); + auto end = value.data() + value.size(); + auto result = std::from_chars(begin, end, parsed); + if (result.ec != std::errc()) { + return status::InvalidArgument("Invalid integer value '", value, + "': not an integer", value); + } else if (result.ptr != end) { + return status::InvalidArgument("Invalid integer value '", value, + "': trailing data", value); + } + return parsed; + } else { + return status::InvalidArgument("Invalid integer value ", this->Format()); + } + }, + value_); + } /// \brief Get the value if it is a string. - Result AsString() const; + Result AsString() const { + return std::visit( + [&](auto&& value) -> Result { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return value; + } else { + return status::InvalidArgument("Invalid string value ", this->Format()); + } + }, + value_); + } + + /// \brief Provide a human-readable summary of the value + std::string Format() const { + return std::visit( + [&](auto&& value) -> std::string { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return "(NULL)"; + } else if constexpr (std::is_same_v) { + return std::string("'") + value + "'"; + } else if constexpr (std::is_same_v>) { + return std::string("(") + std::to_string(value.size()) + " bytes)"; + } else { + return std::to_string(value); + } + }, + value_); + } private: Value value_; // Methods used by trampolines to export option values in C below friend class ObjectBase; - AdbcStatusCode CGet(char* out, size_t* length, AdbcError* error) const; - AdbcStatusCode CGet(uint8_t* out, size_t* length, AdbcError* error) const; - AdbcStatusCode CGet(int64_t* out, AdbcError* error) const; - AdbcStatusCode CGet(double* out, AdbcError* error) const; + AdbcStatusCode CGet(char* out, size_t* length, AdbcError* error) const { + { + if (!length || (!out && *length > 0)) { + return status::InvalidArgument("Must provide both out and length to GetOption") + .ToAdbc(error); + } + return std::visit( + [&](auto&& value) -> AdbcStatusCode { + using T = std::decay_t; + if constexpr (std::is_same_v) { + size_t value_size_with_terminator = value.size() + 1; + if (*length >= value_size_with_terminator) { + std::memcpy(out, value.data(), value.size()); + out[value.size()] = 0; + } + *length = value_size_with_terminator; + return ADBC_STATUS_OK; + } else if constexpr (std::is_same_v) { + return status::NotFound("Unknown option").ToAdbc(error); + } else { + return status::NotFound("Option value is not a string").ToAdbc(error); + } + }, + value_); + } + } + AdbcStatusCode CGet(uint8_t* out, size_t* length, AdbcError* error) const { + if (!length || (!out && *length > 0)) { + return status::InvalidArgument("Must provide both out and length to GetOption") + .ToAdbc(error); + } + return std::visit( + [&](auto&& value) -> AdbcStatusCode { + using T = std::decay_t; + if constexpr (std::is_same_v || + std::is_same_v>) { + if (*length >= value.size()) { + std::memcpy(out, value.data(), value.size()); + } + *length = value.size(); + return ADBC_STATUS_OK; + } else if constexpr (std::is_same_v) { + return status::NotFound("Unknown option").ToAdbc(error); + } else { + return status::NotFound("Option value is not a bytestring").ToAdbc(error); + } + }, + value_); + } + AdbcStatusCode CGet(int64_t* out, AdbcError* error) const { + 
{ + if (!out) { + return status::InvalidArgument("Must provide out to GetOption").ToAdbc(error); + } + return std::visit( + [&](auto&& value) -> AdbcStatusCode { + using T = std::decay_t; + if constexpr (std::is_same_v) { + *out = value; + return ADBC_STATUS_OK; + } else if constexpr (std::is_same_v) { + return status::NotFound("Unknown option").ToAdbc(error); + } else { + return status::NotFound("Option value is not an integer").ToAdbc(error); + } + }, + value_); + } + } + AdbcStatusCode CGet(double* out, AdbcError* error) const { + if (!out) { + return status::InvalidArgument("Must provide out to GetOption").ToAdbc(error); + } + return std::visit( + [&](auto&& value) -> AdbcStatusCode { + using T = std::decay_t; + if constexpr (std::is_same_v || std::is_same_v) { + *out = static_cast(value); + return ADBC_STATUS_OK; + } else if constexpr (std::is_same_v) { + return status::NotFound("Unknown option").ToAdbc(error); + } else { + return status::NotFound("Option value is not a double").ToAdbc(error); + } + }, + value_); + } }; /// \brief Base class for private_data of AdbcDatabase, AdbcConnection, and @@ -122,7 +269,7 @@ class ObjectBase { /// /// Called after 0 or more SetOption calls. Generally, you won't need to /// override this directly. Instead, use the typed InitImpl provided by - /// DatabaseBase/ConnectionBase/StatementBase. + /// Database/Connection/Statement. /// /// \param[in] parent A pointer to the AdbcDatabase or AdbcConnection /// implementation as appropriate, or nullptr. @@ -140,7 +287,7 @@ class ObjectBase { /// the destructor. /// /// Generally, you won't need to override this directly. Instead, use the - /// typed ReleaseImpl provided by DatabaseBase/ConnectionBase/StatementBase. + /// typed ReleaseImpl provided by Database/Connection/Statement. virtual AdbcStatusCode Release(AdbcError* error) { return ADBC_STATUS_OK; } /// \brief Get an option value. @@ -310,11 +457,22 @@ class Driver { } auto error_obj = reinterpret_cast(error->private_data); + if (!error_obj) { + return 0; + } return error_obj->CDetailCount(); } static AdbcErrorDetail CErrorGetDetail(const AdbcError* error, int index) { + if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + return {nullptr, nullptr, 0}; + } + auto error_obj = reinterpret_cast(error->private_data); + if (!error_obj) { + return {nullptr, nullptr, 0}; + } + return error_obj->CDetail(index); } @@ -622,27 +780,371 @@ class Driver { #undef CHECK_INIT }; -} // namespace adbc::driver +template +class BaseDatabase : public ObjectBase { + public: + using Base = BaseDatabase; -/// \brief Formatter for Option values. 
-template <> -struct fmt::formatter : fmt::nested_formatter { - auto format(const adbc::driver::Option& option, fmt::format_context& ctx) const { - return write_padded(ctx, [=](auto out) { - return std::visit( - [&](auto&& value) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return fmt::format_to(out, "(NULL)"); - } else if constexpr (std::is_same_v) { - return fmt::format_to(out, "'{}'", value); - } else if constexpr (std::is_same_v>) { - return fmt::format_to(out, "({} bytes)", value.size()); - } else { - return fmt::format_to(out, "{}", value); - } - }, - option.value()); - }); + BaseDatabase() : ObjectBase() {} + ~BaseDatabase() = default; + + /// \internal + AdbcStatusCode Init(void* parent, AdbcError* error) override { + RAISE_STATUS(error, impl().InitImpl()); + return ObjectBase::Init(parent, error); + } + + /// \internal + AdbcStatusCode Release(AdbcError* error) override { + RAISE_STATUS(error, impl().ReleaseImpl()); + return ADBC_STATUS_OK; + } + + /// \internal + AdbcStatusCode SetOption(std::string_view key, Option value, + AdbcError* error) override { + RAISE_STATUS(error, impl().SetOptionImpl(key, std::move(value))); + return ADBC_STATUS_OK; + } + + /// \brief Initialize the database. + virtual Status InitImpl() { return status::Ok(); } + + /// \brief Release the database. + virtual Status ReleaseImpl() { return status::Ok(); } + + /// \brief Set an option. May be called prior to InitImpl. + virtual Status SetOptionImpl(std::string_view key, Option value) { + return status::NotImplemented(Derived::kErrorPrefix, " Unknown database option ", key, + "=", value.Format()); + } + + private: + Derived& impl() { return static_cast(*this); } +}; + +template +class BaseConnection : public ObjectBase { + public: + using Base = BaseConnection; + + /// \brief Whether autocommit is enabled or not (by default: enabled). + enum class AutocommitState { + kAutocommit, + kTransaction, + }; + + BaseConnection() : ObjectBase() {} + ~BaseConnection() = default; + + /// \internal + AdbcStatusCode Init(void* parent, AdbcError* error) override { + RAISE_STATUS(error, impl().InitImpl(parent)); + return ObjectBase::Init(parent, error); + } + + /// \brief Initialize the database. + virtual Status InitImpl(void* parent) { return status::Ok(); } + + /// \internal + AdbcStatusCode Cancel(AdbcError* error) { return impl().CancelImpl().ToAdbc(error); } + + Status CancelImpl() { return status::NotImplemented("Cancel"); } + + /// \internal + AdbcStatusCode Commit(AdbcError* error) { return impl().CommitImpl().ToAdbc(error); } + + Status CommitImpl() { return status::NotImplemented("Commit"); } + + /// \internal + AdbcStatusCode GetInfo(const uint32_t* info_codes, size_t info_codes_length, + ArrowArrayStream* out, AdbcError* error) { + std::vector codes(info_codes, info_codes + info_codes_length); + RAISE_STATUS(error, impl().GetInfoImpl(codes, out)); + return ADBC_STATUS_OK; + } + + Status GetInfoImpl(const std::vector info_codes, ArrowArrayStream* out) { + return status::NotImplemented("GetInfo"); + } + + /// \internal + AdbcStatusCode GetObjects(int c_depth, const char* catalog, const char* db_schema, + const char* table_name, const char** table_type, + const char* column_name, ArrowArrayStream* out, + AdbcError* error) { + const auto catalog_filter = + catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt; + const auto schema_filter = + db_schema ? std::make_optional(std::string_view(db_schema)) : std::nullopt; + const auto table_filter = + table_name ? 
std::make_optional(std::string_view(table_name)) : std::nullopt; + const auto column_filter = + column_name ? std::make_optional(std::string_view(column_name)) : std::nullopt; + std::vector table_type_filter; + while (table_type && *table_type) { + if (*table_type) { + table_type_filter.push_back(std::string_view(*table_type)); + } + table_type++; + } + + RAISE_STATUS( + error, impl().GetObjectsImpl(c_depth, catalog_filter, schema_filter, table_filter, + column_filter, table_type_filter, out)); + + return ADBC_STATUS_OK; + } + + Status GetObjectsImpl(int c_depth, std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + const std::vector& table_types, + struct ArrowArrayStream* out) { + return status::NotImplemented("GetObjects"); + } + + /// \internal + AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema, + const char* table_name, char approximate, + ArrowArrayStream* out, AdbcError* error) { + const auto catalog_filter = + catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt; + const auto schema_filter = + db_schema ? std::make_optional(std::string_view(db_schema)) : std::nullopt; + const auto table_filter = + table_name ? std::make_optional(std::string_view(table_name)) : std::nullopt; + RAISE_STATUS(error, impl().GetStatisticsImpl(catalog_filter, schema_filter, + table_filter, approximate != 0, out)); + return ADBC_STATUS_OK; + } + + Status GetStatisticsImpl(std::optional catalog, + std::optional db_schema, + std::optional table_name, bool approximate, + ArrowArrayStream* out) { + return status::NotImplemented("GetStatistics"); + } + + /// \internal + AdbcStatusCode GetStatisticNames(ArrowArrayStream* out, AdbcError* error) { + RAISE_STATUS(error, impl().GetStatisticNames(out)); + return ADBC_STATUS_OK; + } + + Status GetStatisticNames(ArrowArrayStream* out) { + return status::NotImplemented("GetStatisticNames"); + } + + /// \internal + AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, + const char* table_name, ArrowSchema* schema, + AdbcError* error) { + if (!table_name) { + return status::InvalidArgument(Derived::kErrorPrefix, + " GetTableSchema: must provide table_name") + .ToAdbc(error); + } + + std::optional catalog_param = + catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt; + std::optional db_schema_param = + db_schema ? 
std::make_optional(std::string_view(db_schema)) : std::nullopt; + + RAISE_STATUS(error, impl().GetTableSchemaImpl(catalog_param, db_schema_param, + table_name, schema)); + return ADBC_STATUS_OK; } + + Status GetTableSchemaImpl(std::optional catalog, + std::optional db_schema, + std::string_view table_name, ArrowSchema* out) { + return status::NotImplemented("GetTableSchema"); + } + + /// \internal + AdbcStatusCode GetTableTypes(ArrowArrayStream* out, AdbcError* error) { + RAISE_STATUS(error, impl().GetTableTypesImpl(out)); + return ADBC_STATUS_OK; + } + + Status GetTableTypesImpl(ArrowArrayStream* out) { + return status::NotImplemented("GetTableTypes"); + } + + /// \internal + AdbcStatusCode ReadPartition(const uint8_t* serialized_partition, + size_t serialized_length, ArrowArrayStream* out, + AdbcError* error) { + std::string_view partition(reinterpret_cast(serialized_partition), + serialized_length); + RAISE_STATUS(error, impl().ReadPartitionImpl(partition, out)); + return ADBC_STATUS_OK; + } + + Status ReadPartitionImpl(std::string_view serialized_partition, ArrowArrayStream* out) { + return status::NotImplemented("ReadPartition"); + } + + /// \internal + AdbcStatusCode Release(AdbcError* error) override { + RAISE_STATUS(error, impl().ReleaseImpl()); + return ADBC_STATUS_OK; + } + + Status ReleaseImpl() { return status::Ok(); } + + /// \internal + AdbcStatusCode Rollback(AdbcError* error) { + RAISE_STATUS(error, impl().RollbackImpl()); + return ADBC_STATUS_OK; + } + + Status RollbackImpl() { return status::NotImplemented("Rollback"); } + + /// \internal + AdbcStatusCode SetOption(std::string_view key, Option value, + AdbcError* error) override { + RAISE_STATUS(error, impl().SetOptionImpl(key, value)); + return ADBC_STATUS_OK; + } + + /// \brief Set an option. May be called prior to InitImpl. + virtual Status SetOptionImpl(std::string_view key, Option value) { + return status::NotImplemented(Derived::kErrorPrefix, " Unknown connection option ", + key, "=", value.Format()); + } + + private: + Derived& impl() { return static_cast(*this); } +}; + +template +class BaseStatement : public ObjectBase { + public: + using Base = BaseStatement; + + /// \internal + AdbcStatusCode Init(void* parent, AdbcError* error) override { + RAISE_STATUS(error, impl().InitImpl(parent)); + return ObjectBase::Init(parent, error); + } + + /// \brief Initialize the statement. + Status InitImpl(void* parent) { return status::Ok(); } + + /// \internal + AdbcStatusCode Release(AdbcError* error) override { + RAISE_STATUS(error, impl().ReleaseImpl()); + return ADBC_STATUS_OK; + } + + Status ReleaseImpl() { return status::Ok(); } + + /// \internal + AdbcStatusCode SetOption(std::string_view key, Option value, + AdbcError* error) override { + RAISE_STATUS(error, impl().SetOptionImpl(key, value)); + return ADBC_STATUS_OK; + } + + /// \brief Set an option. May be called prior to InitImpl. 
+ virtual Status SetOptionImpl(std::string_view key, Option value) { + return status::NotImplemented(Derived::kErrorPrefix, " Unknown statement option ", + key, "=", value.Format()); + } + + AdbcStatusCode ExecuteQuery(ArrowArrayStream* stream, int64_t* rows_affected, + AdbcError* error) { + RAISE_RESULT(error, int64_t rows_affected_result, impl().ExecuteQueryImpl(stream)); + if (rows_affected) { + *rows_affected = rows_affected_result; + } + + return ADBC_STATUS_OK; + } + + Result ExecuteQueryImpl(ArrowArrayStream* stream) { + return status::NotImplemented("ExecuteQuery"); + } + + AdbcStatusCode ExecuteSchema(ArrowSchema* schema, AdbcError* error) { + RAISE_STATUS(error, impl().ExecuteSchemaImpl(schema)); + return ADBC_STATUS_OK; + } + + Status ExecuteSchemaImpl(ArrowSchema* schema) { + return status::NotImplemented("ExecuteSchema"); + } + + AdbcStatusCode Prepare(AdbcError* error) { + RAISE_STATUS(error, impl().PrepareImpl()); + return ADBC_STATUS_OK; + } + + Status PrepareImpl() { return status::NotImplemented("Prepare"); } + + AdbcStatusCode SetSqlQuery(const char* query, AdbcError* error) { + RAISE_STATUS(error, impl().SetSqlQueryImpl(query)); + return ADBC_STATUS_OK; + } + + Status SetSqlQueryImpl(std::string_view query) { + return status::NotImplemented("SetSqlQuery"); + } + + AdbcStatusCode SetSubstraitPlan(const uint8_t* plan, size_t length, AdbcError* error) { + RAISE_STATUS(error, impl().SetSubstraitPlanImpl(std::string_view( + reinterpret_cast(plan), length))); + return ADBC_STATUS_OK; + } + + Status SetSubstraitPlanImpl(std::string_view plan) { + return status::NotImplemented("SetSubstraitPlan"); + } + + AdbcStatusCode Bind(ArrowArray* values, ArrowSchema* schema, AdbcError* error) { + RAISE_STATUS(error, impl().BindImpl(values, schema)); + return ADBC_STATUS_OK; + } + + Status BindImpl(ArrowArray* values, ArrowSchema* schema) { + return status::NotImplemented("Bind"); + } + + AdbcStatusCode BindStream(ArrowArrayStream* stream, AdbcError* error) { + RAISE_STATUS(error, impl().BindStreamImpl(stream)); + return ADBC_STATUS_OK; + } + + Status BindStreamImpl(ArrowArrayStream* stream) { + return status::NotImplemented("BindStream"); + } + + AdbcStatusCode GetParameterSchema(ArrowSchema* schema, AdbcError* error) { + RAISE_STATUS(error, impl().GetParameterSchemaImpl(schema)); + return ADBC_STATUS_OK; + } + + Status GetParameterSchemaImpl(struct ArrowSchema* schema) { + return status::NotImplemented("GetParameterSchema"); + } + + AdbcStatusCode ExecutePartitions(ArrowSchema* schema, AdbcPartitions* partitions, + int64_t* rows_affected, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + AdbcStatusCode Cancel(AdbcError* error) { + RAISE_STATUS(error, impl().Cancel()); + return ADBC_STATUS_OK; + } + + Status Cancel() { return status::NotImplemented("Cancel"); } + + private: + Derived& impl() { return static_cast(*this); } }; + +} // namespace adbc::driver diff --git a/c/driver/framework/base_driver_test.cc b/c/driver/framework/base_driver_test.cc new file mode 100644 index 0000000000..1d8d61f60f --- /dev/null +++ b/c/driver/framework/base_driver_test.cc @@ -0,0 +1,237 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include +#include "driver/framework/base_driver.h" +#include "driver/framework/connection.h" +#include "driver/framework/database.h" +#include "driver/framework/statement.h" + +// Self-contained version of the Handle +static inline void clean_up(AdbcDriver* ptr) { ptr->release(ptr, nullptr); } + +static inline void clean_up(AdbcDatabase* ptr) { + ptr->private_driver->DatabaseRelease(ptr, nullptr); +} + +static inline void clean_up(AdbcConnection* ptr) { + ptr->private_driver->ConnectionRelease(ptr, nullptr); +} + +static inline void clean_up(AdbcStatement* ptr) { + ptr->private_driver->StatementRelease(ptr, nullptr); +} + +template +class Handle { + public: + explicit Handle(T* value) : value_(value) {} + + ~Handle() { clean_up(value_); } + + private: + T* value_; +}; + +namespace { + +class BaseVoidDatabase : public adbc::driver::BaseDatabase { + public: + [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[void]"; +}; + +class BaseVoidConnection : public adbc::driver::BaseConnection { + public: + [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[void]"; +}; + +class BaseVoidStatement : public adbc::driver::BaseStatement { + public: + [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[void]"; +}; + +using BaseVoidDriver = + adbc::driver::Driver; +} // namespace + +AdbcStatusCode BaseVoidDriverInitFunc(int version, void* raw_driver, AdbcError* error) { + return BaseVoidDriver::Init(version, raw_driver, error); +} + +TEST(TestDriverBase, TestBaseVoidDriverMethods) { + // Checks that wires are plugged in for a framework-based driver based only on what is + // available in base_driver.h + + struct AdbcDriver driver; + memset(&driver, 0, sizeof(driver)); + ASSERT_EQ(BaseVoidDriverInitFunc(ADBC_VERSION_1_1_0, &driver, nullptr), ADBC_STATUS_OK); + Handle driver_handle(&driver); + + // Database methods are only option related + struct AdbcDatabase database; + memset(&database, 0, sizeof(database)); + ASSERT_EQ(driver.DatabaseNew(&database, nullptr), ADBC_STATUS_OK); + database.private_driver = &driver; + Handle database_handle(&database); + ASSERT_EQ(driver.DatabaseInit(&database, nullptr), ADBC_STATUS_OK); + + // Test connection methods + struct AdbcConnection connection; + memset(&connection, 0, sizeof(connection)); + ASSERT_EQ(driver.ConnectionNew(&connection, nullptr), ADBC_STATUS_OK); + connection.private_driver = &driver; + Handle connection_handle(&connection); + ASSERT_EQ(driver.ConnectionInit(&connection, &database, nullptr), ADBC_STATUS_OK); + + EXPECT_EQ(driver.ConnectionCommit(&connection, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionGetInfo(&connection, nullptr, 0, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionGetObjects(&connection, 0, nullptr, nullptr, 0, nullptr, + nullptr, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionGetTableSchema(&connection, nullptr, nullptr, nullptr, + nullptr, nullptr), + ADBC_STATUS_INVALID_ARGUMENT); + EXPECT_EQ(driver.ConnectionGetTableTypes(&connection, nullptr, 
nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionReadPartition(&connection, nullptr, 0, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionRollback(&connection, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionCancel(&connection, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionGetStatistics(&connection, nullptr, nullptr, nullptr, 0, + nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionGetStatisticNames(&connection, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + + // Test statement methods + struct AdbcStatement statement; + memset(&statement, 0, sizeof(statement)); + ASSERT_EQ(driver.StatementNew(&connection, &statement, nullptr), ADBC_STATUS_OK); + statement.private_driver = &driver; + Handle statement_handle(&statement); + + EXPECT_EQ(driver.StatementExecuteQuery(&statement, nullptr, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementExecuteSchema(&statement, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementPrepare(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementSetSqlQuery(&statement, "", nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementSetSubstraitPlan(&statement, nullptr, 0, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementBind(&statement, nullptr, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementBindStream(&statement, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementCancel(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); +} + +namespace { + +class VoidDatabase : public adbc::driver::Database { + public: + [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[void]"; +}; + +class VoidConnection : public adbc::driver::Connection { + public: + [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[void]"; +}; + +class VoidStatement : public adbc::driver::Statement { + public: + [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[void]"; +}; + +using VoidDriver = adbc::driver::Driver; +} // namespace + +AdbcStatusCode VoidDriverInitFunc(int version, void* raw_driver, AdbcError* error) { + return VoidDriver::Init(version, raw_driver, error); +} + +TEST(TestDriverBase, TestVoidDriverMethods) { + // Checks that wires are plugged in for a framework-based driver based on + // the more-batteries-included Database, Connection, and Statement + + struct AdbcDriver driver; + memset(&driver, 0, sizeof(driver)); + ASSERT_EQ(VoidDriverInitFunc(ADBC_VERSION_1_1_0, &driver, nullptr), ADBC_STATUS_OK); + Handle driver_handle(&driver); + + // Database methods are only option related + struct AdbcDatabase database; + memset(&database, 0, sizeof(database)); + ASSERT_EQ(driver.DatabaseNew(&database, nullptr), ADBC_STATUS_OK); + database.private_driver = &driver; + Handle database_handle(&database); + ASSERT_EQ(driver.DatabaseInit(&database, nullptr), ADBC_STATUS_OK); + + // Test connection methods + struct AdbcConnection connection; + memset(&connection, 0, sizeof(connection)); + ASSERT_EQ(driver.ConnectionNew(&connection, nullptr), ADBC_STATUS_OK); + connection.private_driver = &driver; + Handle connection_handle(&connection); + ASSERT_EQ(driver.ConnectionInit(&connection, &database, nullptr), ADBC_STATUS_OK); + + EXPECT_EQ(driver.ConnectionCommit(&connection, nullptr), ADBC_STATUS_INVALID_STATE); + EXPECT_EQ(driver.ConnectionGetInfo(&connection, 
nullptr, 0, nullptr, nullptr), + ADBC_STATUS_INVALID_ARGUMENT); + EXPECT_EQ(driver.ConnectionGetObjects(&connection, 0, nullptr, nullptr, 0, nullptr, + nullptr, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionGetTableSchema(&connection, nullptr, nullptr, nullptr, + nullptr, nullptr), + ADBC_STATUS_INVALID_ARGUMENT); + EXPECT_EQ(driver.ConnectionGetTableTypes(&connection, nullptr, nullptr), + ADBC_STATUS_INVALID_ARGUMENT); + EXPECT_EQ(driver.ConnectionReadPartition(&connection, nullptr, 0, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionRollback(&connection, nullptr), ADBC_STATUS_INVALID_STATE); + EXPECT_EQ(driver.ConnectionCancel(&connection, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionGetStatistics(&connection, nullptr, nullptr, nullptr, 0, + nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.ConnectionGetStatisticNames(&connection, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + + // Test statement methods + struct AdbcStatement statement; + memset(&statement, 0, sizeof(statement)); + ASSERT_EQ(driver.StatementNew(&connection, &statement, nullptr), ADBC_STATUS_OK); + statement.private_driver = &driver; + Handle statement_handle(&statement); + + EXPECT_EQ(driver.StatementExecuteQuery(&statement, nullptr, nullptr, nullptr), + ADBC_STATUS_INVALID_STATE); + EXPECT_EQ(driver.StatementExecuteSchema(&statement, nullptr, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementPrepare(&statement, nullptr), ADBC_STATUS_INVALID_STATE); + EXPECT_EQ(driver.StatementSetSqlQuery(&statement, "", nullptr), ADBC_STATUS_OK); + EXPECT_EQ(driver.StatementSetSubstraitPlan(&statement, nullptr, 0, nullptr), + ADBC_STATUS_NOT_IMPLEMENTED); + EXPECT_EQ(driver.StatementBind(&statement, nullptr, nullptr, nullptr), + ADBC_STATUS_INVALID_ARGUMENT); + EXPECT_EQ(driver.StatementBindStream(&statement, nullptr, nullptr), + ADBC_STATUS_INVALID_ARGUMENT); + EXPECT_EQ(driver.StatementCancel(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); +} diff --git a/c/driver/framework/catalog.cc b/c/driver/framework/catalog.cc deleted file mode 100644 index 3860ebb9e0..0000000000 --- a/c/driver/framework/catalog.cc +++ /dev/null @@ -1,263 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
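For orientation, the test above assembles a working driver from nothing but the CRTP bases in base_driver.h. A minimal sketch of that same pattern (assuming, as the test suggests, that each base template takes the derived class as its parameter and that Driver takes the database, connection, and statement types in that order) would be roughly:

    #include <string_view>

    #include "driver/framework/base_driver.h"

    namespace {

    // Each class only has to declare kErrorPrefix; every ADBC entry point then
    // falls back to the framework's NOT_IMPLEMENTED defaults.
    class MyDatabase : public adbc::driver::BaseDatabase<MyDatabase> {
     public:
      constexpr static std::string_view kErrorPrefix = "[my-driver]";
    };

    class MyConnection : public adbc::driver::BaseConnection<MyConnection> {
     public:
      constexpr static std::string_view kErrorPrefix = "[my-driver]";
    };

    class MyStatement : public adbc::driver::BaseStatement<MyStatement> {
     public:
      constexpr static std::string_view kErrorPrefix = "[my-driver]";
    };

    using MyDriver = adbc::driver::Driver<MyDatabase, MyConnection, MyStatement>;

    }  // namespace

    // Driver init function: fills the AdbcDriver vtable with the framework
    // trampolines, mirroring BaseVoidDriverInitFunc in the test above.
    AdbcStatusCode MyDriverInitFunc(int version, void* raw_driver, AdbcError* error) {
      return MyDriver::Init(version, raw_driver, error);
    }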
- -#include "driver/framework/catalog.h" - -#include - -namespace adbc::driver { -Status AdbcInitConnectionGetInfoSchema(struct ArrowSchema* schema, - struct ArrowArray* array) { - ArrowSchemaInit(schema); - UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); - - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_UINT32)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "info_name")); - schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - - struct ArrowSchema* info_value = schema->children[1]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeUnion(info_value, NANOARROW_TYPE_DENSE_UNION, 6)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value, "info_value")); - - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[0], "string_value")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[1], NANOARROW_TYPE_BOOL)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[1], "bool_value")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[2], NANOARROW_TYPE_INT64)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[2], "int64_value")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[3], NANOARROW_TYPE_INT32)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[3], "int32_bitmask")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[4], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[4], "string_list")); - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[5], NANOARROW_TYPE_MAP)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(info_value->children[5], "int32_to_int32_list_map")); - - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[4]->children[0], - NANOARROW_TYPE_STRING)); - - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[5]->children[0]->children[0], - NANOARROW_TYPE_INT32)); - info_value->children[5]->children[0]->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[5]->children[0]->children[1], - NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO( - Internal, - ArrowSchemaSetType(info_value->children[5]->children[0]->children[1]->children[0], - NANOARROW_TYPE_INT32)); - - struct ArrowError na_error = {0}; - UNWRAP_NANOARROW(na_error, Internal, - ArrowArrayInitFromSchema(array, schema, &na_error)); - UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array)); - - return status::Ok(); -} - -Status AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, - std::string_view info_value) { - UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); - // Append to type variant - struct ArrowStringView value; - value.data = info_value.data(); - value.size_bytes = static_cast(info_value.size()); - UNWRAP_ERRNO(Internal, ArrowArrayAppendString(array->children[1]->children[0], value)); - // Append type code/offset - UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/0)); - return status::Ok(); -} - -Status AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, uint32_t info_code, - int64_t info_value) { - UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); - // Append to type variant - UNWRAP_ERRNO(Internal, - ArrowArrayAppendInt(array->children[1]->children[2], info_value)); - // Append type 
code/offset - UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2)); - return status::Ok(); -} - -Status AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema) { - ArrowSchemaInit(schema); - UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "catalog_name")); - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[1], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[1], "catalog_db_schemas")); - UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema->children[1]->children[0], 2)); - - struct ArrowSchema* db_schema_schema = schema->children[1]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_tables")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 4)); - - struct ArrowSchema* table_schema = db_schema_schema->children[1]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(table_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[0], "table_name")); - table_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(table_schema->children[1], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[1], "table_type")); - table_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(table_schema->children[2], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[2], "table_columns")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeStruct(table_schema->children[2]->children[0], 19)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(table_schema->children[3], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(table_schema->children[3], "table_constraints")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeStruct(table_schema->children[3]->children[0], 4)); - - struct ArrowSchema* column_schema = table_schema->children[2]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[0], "column_name")); - column_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[1], NANOARROW_TYPE_INT32)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[1], "ordinal_position")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[2], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[2], "remarks")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[3], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[3], "xdbc_data_type")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[4], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[4], 
"xdbc_type_name")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[5], NANOARROW_TYPE_INT32)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[5], "xdbc_column_size")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[6], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[6], "xdbc_decimal_digits")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[7], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[7], "xdbc_num_prec_radix")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[8], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[8], "xdbc_nullable")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[9], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[9], "xdbc_column_def")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[10], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[10], "xdbc_sql_data_type")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[11], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[11], "xdbc_datetime_sub")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[12], NANOARROW_TYPE_INT32)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[12], "xdbc_char_octet_length")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[13], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[13], "xdbc_is_nullable")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[14], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[14], "xdbc_scope_catalog")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[15], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[15], "xdbc_scope_schema")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[16], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[16], "xdbc_scope_table")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[17], NANOARROW_TYPE_BOOL)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[17], "xdbc_is_autoincrement")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[18], NANOARROW_TYPE_BOOL)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[18], - "xdbc_is_generatedcolumn")); - - struct ArrowSchema* constraint_schema = table_schema->children[3]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(constraint_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(constraint_schema->children[0], "constraint_name")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(constraint_schema->children[1], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(constraint_schema->children[1], "constraint_type")); - constraint_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(constraint_schema->children[2], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[2], - "constraint_column_names")); - 
UNWRAP_ERRNO(Internal, ArrowSchemaSetType(constraint_schema->children[2]->children[0], - NANOARROW_TYPE_STRING)); - constraint_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(constraint_schema->children[3], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[3], - "constraint_column_usage")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeStruct(constraint_schema->children[3]->children[0], 4)); - - struct ArrowSchema* usage_schema = constraint_schema->children[3]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(usage_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[0], "fk_catalog")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(usage_schema->children[1], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[1], "fk_db_schema")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(usage_schema->children[2], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[2], "fk_table")); - usage_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(usage_schema->children[3], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[3], "fk_column_name")); - usage_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; - - return status::Ok(); -} -} // namespace adbc::driver diff --git a/c/driver/framework/catalog.h b/c/driver/framework/catalog.h deleted file mode 100644 index a415765cac..0000000000 --- a/c/driver/framework/catalog.h +++ /dev/null @@ -1,154 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "driver/framework/status.h" - -namespace adbc::driver { - -/// \defgroup adbc-framework-catalog Catalog Utilities -/// Utilities for implementing catalog/metadata-related functions. -/// -/// @{ - -/// \brief The GetObjects level. -enum class GetObjectsDepth { - kCatalogs, - kSchemas, - kTables, - kColumns, -}; - -/// \brief Helper to implement GetObjects. 
-struct GetObjectsHelper { - virtual ~GetObjectsHelper() = default; - - struct Table { - std::string_view name; - std::string_view type; - }; - - struct ColumnXdbc { - std::optional xdbc_data_type; - std::optional xdbc_type_name; - std::optional xdbc_column_size; - std::optional xdbc_decimal_digits; - std::optional xdbc_num_prec_radix; - std::optional xdbc_nullable; - std::optional xdbc_column_def; - std::optional xdbc_sql_data_type; - std::optional xdbc_datetime_sub; - std::optional xdbc_char_octet_length; - std::optional xdbc_is_nullable; - std::optional xdbc_scope_catalog; - std::optional xdbc_scope_schema; - std::optional xdbc_scope_table; - std::optional xdbc_is_autoincrement; - std::optional xdbc_is_generatedcolumn; - }; - - struct Column { - std::string_view column_name; - int32_t ordinal_position; - std::optional remarks; - std::optional xdbc; - }; - - struct ConstraintUsage { - std::optional catalog; - std::optional schema; - std::string_view table; - std::string_view column; - }; - - struct Constraint { - std::optional name; - std::string_view type; - std::vector column_names; - std::optional> usage; - }; - - Status Close() { return status::Ok(); } - - /// \brief Fetch all metadata needed. The driver is free to delay loading - /// but this gives it a chance to load data up front. - virtual Status Load(GetObjectsDepth depth, - std::optional catalog_filter, - std::optional schema_filter, - std::optional table_filter, - std::optional column_filter, - const std::vector& table_types) { - return status::NotImplemented("GetObjects"); - } - - virtual Status LoadCatalogs() { - return status::NotImplemented("GetObjects at depth = catalog"); - }; - - virtual Result> NextCatalog() { return std::nullopt; } - - virtual Status LoadSchemas(std::string_view catalog) { - return status::NotImplemented("GetObjects at depth = schema"); - }; - - virtual Result> NextSchema() { return std::nullopt; } - - virtual Status LoadTables(std::string_view catalog, std::string_view schema) { - return status::NotImplemented("GetObjects at depth = table"); - }; - - virtual Result> NextTable() { return std::nullopt; } - - virtual Status LoadColumns(std::string_view catalog, std::string_view schema, - std::string_view table) { - return status::NotImplemented("GetObjects at depth = column"); - }; - - virtual Result> NextColumn() { return std::nullopt; } - - virtual Result> NextConstraint() { return std::nullopt; } -}; - -struct InfoValue { - uint32_t code; - std::variant value; - - explicit InfoValue(uint32_t code, std::variant value) - : code(code), value(std::move(value)) {} -}; - -Status AdbcInitConnectionGetInfoSchema(struct ArrowSchema* schema, - struct ArrowArray* array); -Status AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, - std::string_view info_value); -Status AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, uint32_t info_code, - int64_t info_value); -Status AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema); -/// @} - -} // namespace adbc::driver diff --git a/c/driver/framework/base_connection.h b/c/driver/framework/connection.h similarity index 71% rename from c/driver/framework/base_connection.h rename to c/driver/framework/connection.h index f9ddc1c6a2..da3aae1070 100644 --- a/c/driver/framework/base_connection.h +++ b/c/driver/framework/connection.h @@ -24,15 +24,11 @@ #include #include -#include -#include -#include +#include -#include "driver/common/options.h" -#include "driver/common/utils.h" #include "driver/framework/base_driver.h" -#include 
"driver/framework/catalog.h" #include "driver/framework/objects.h" +#include "driver/framework/utility.h" namespace adbc::driver { /// \brief The CRTP base implementation of an AdbcConnection. @@ -43,9 +39,9 @@ namespace adbc::driver { /// define a constexpr static symbol called kErrorPrefix that is used to /// construct error messages. template -class ConnectionBase : public ObjectBase { +class Connection : public ObjectBase { public: - using Base = ConnectionBase; + using Base = Connection; /// \brief Whether autocommit is enabled or not (by default: enabled). enum class AutocommitState { @@ -53,8 +49,8 @@ class ConnectionBase : public ObjectBase { kTransaction, }; - ConnectionBase() : ObjectBase() {} - ~ConnectionBase() = default; + Connection() : ObjectBase() {} + ~Connection() = default; /// \internal AdbcStatusCode Init(void* parent, AdbcError* error) override { @@ -71,8 +67,8 @@ class ConnectionBase : public ObjectBase { AdbcStatusCode Commit(AdbcError* error) { switch (autocommit_) { case AutocommitState::kAutocommit: - return status::InvalidState("{} No active transaction, cannot commit", - Derived::kErrorPrefix) + return status::InvalidState(Derived::kErrorPrefix, + " No active transaction, cannot commit") .ToAdbc(error); case AutocommitState::kTransaction: return impl().CommitImpl().ToAdbc(error); @@ -84,36 +80,14 @@ class ConnectionBase : public ObjectBase { /// \internal AdbcStatusCode GetInfo(const uint32_t* info_codes, size_t info_codes_length, ArrowArrayStream* out, AdbcError* error) { - std::vector codes(info_codes, info_codes + info_codes_length); - RAISE_RESULT(error, auto infos, impl().InfoImpl(codes)); - - nanoarrow::UniqueSchema schema; - nanoarrow::UniqueArray array; - RAISE_STATUS(error, AdbcInitConnectionGetInfoSchema(schema.get(), array.get())); - - for (const auto& info : infos) { - RAISE_STATUS( - error, - std::visit( - [&](auto&& value) -> Status { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return AdbcConnectionGetInfoAppendString(array.get(), info.code, value); - } else if constexpr (std::is_same_v) { - return AdbcConnectionGetInfoAppendInt(array.get(), info.code, value); - } else { - static_assert(!sizeof(T), "info value type not implemented"); - } - return status::Ok(); - }, - info.value)); - CHECK_NA(INTERNAL, ArrowArrayFinishElement(array.get()), error); + if (!out) { + RAISE_STATUS(error, status::InvalidArgument("out must be non-null")); } - struct ArrowError na_error = {0}; - CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array.get(), &na_error), - &na_error, error); - return BatchToArrayStream(array.get(), schema.get(), out, error); + std::vector codes(info_codes, info_codes + info_codes_length); + RAISE_RESULT(error, auto infos, impl().InfoImpl(codes)); + RAISE_STATUS(error, MakeGetInfoStream(infos, out)); + return ADBC_STATUS_OK; } /// \internal @@ -152,20 +126,17 @@ class ConnectionBase : public ObjectBase { depth = GetObjectsDepth::kTables; break; default: - return status::InvalidArgument("{} GetObjects: invalid depth {}", - Derived::kErrorPrefix, c_depth) + return status::InvalidArgument(Derived::kErrorPrefix, + " GetObjects: invalid depth ", c_depth) .ToAdbc(error); } RAISE_RESULT(error, auto helper, impl().GetObjectsImpl()); - nanoarrow::UniqueSchema schema; - nanoarrow::UniqueArray array; - auto status = - BuildGetObjects(helper.get(), depth, catalog_filter, schema_filter, table_filter, - column_filter, table_type_filter, schema.get(), array.get()); + auto status = BuildGetObjects(helper.get(), depth, 
catalog_filter, schema_filter, + table_filter, column_filter, table_type_filter, out); RAISE_STATUS(error, helper->Close()); RAISE_STATUS(error, status); - return BatchToArrayStream(array.get(), schema.get(), out, error); + return ADBC_STATUS_OK; } /// \internal @@ -210,8 +181,8 @@ class ConnectionBase : public ObjectBase { const char* table_name, ArrowSchema* schema, AdbcError* error) { if (!table_name) { - return status::InvalidArgument("{} GetTableSchema: must provide table_name", - Derived::kErrorPrefix) + return status::InvalidArgument(Derived::kErrorPrefix, + " GetTableSchema: must provide table_name") .ToAdbc(error); } std::memset(schema, 0, sizeof(*schema)); @@ -228,36 +199,13 @@ class ConnectionBase : public ObjectBase { /// \internal AdbcStatusCode GetTableTypes(ArrowArrayStream* out, AdbcError* error) { - RAISE_RESULT(error, std::vector table_types, impl().GetTableTypesImpl()); - - nanoarrow::UniqueArray array; - nanoarrow::UniqueSchema schema; - ArrowSchemaInit(schema.get()); - - CHECK_NA(INTERNAL, ArrowSchemaSetType(schema.get(), NANOARROW_TYPE_STRUCT), error); - CHECK_NA(INTERNAL, ArrowSchemaAllocateChildren(schema.get(), /*num_columns=*/1), - error); - ArrowSchemaInit(schema.get()->children[0]); - CHECK_NA(INTERNAL, - ArrowSchemaSetType(schema.get()->children[0], NANOARROW_TYPE_STRING), error); - CHECK_NA(INTERNAL, ArrowSchemaSetName(schema.get()->children[0], "table_type"), - error); - schema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - - CHECK_NA(INTERNAL, ArrowArrayInitFromSchema(array.get(), schema.get(), NULL), error); - CHECK_NA(INTERNAL, ArrowArrayStartAppending(array.get()), error); - - for (std::string const& table_type : table_types) { - CHECK_NA( - INTERNAL, - ArrowArrayAppendString(array->children[0], ArrowCharView(table_type.c_str())), - error); - CHECK_NA(INTERNAL, ArrowArrayFinishElement(array.get()), error); + if (!out) { + RAISE_STATUS(error, status::InvalidArgument("out must be non-null")); } - CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array.get(), NULL), error); - - return BatchToArrayStream(array.get(), schema.get(), out, error); + RAISE_RESULT(error, std::vector table_types, impl().GetTableTypesImpl()); + RAISE_STATUS(error, MakeTableTypesStream(table_types, out)); + return ADBC_STATUS_OK; } /// \internal @@ -276,8 +224,8 @@ class ConnectionBase : public ObjectBase { AdbcStatusCode Rollback(AdbcError* error) { switch (autocommit_) { case AutocommitState::kAutocommit: - return status::InvalidState("{} No active transaction, cannot rollback", - Derived::kErrorPrefix) + return status::InvalidState(Derived::kErrorPrefix, + " No active transaction, cannot rollback") .ToAdbc(error); case AutocommitState::kTransaction: return impl().RollbackImpl().ToAdbc(error); @@ -352,12 +300,12 @@ class ConnectionBase : public ObjectBase { } return status::Ok(); } - return status::NotImplemented("{} Unknown connection option {}={}", - Derived::kErrorPrefix, key, value); + return status::NotImplemented(Derived::kErrorPrefix, " Unknown connection option ", + key, "=", value.Format()); } Status ToggleAutocommitImpl(bool enable_autocommit) { - return status::NotImplemented("{} Cannot change autocommit", Derived::kErrorPrefix); + return status::NotImplemented(Derived::kErrorPrefix, " Cannot change autocommit"); } protected: diff --git a/c/driver/framework/base_database.h b/c/driver/framework/database.h similarity index 63% rename from c/driver/framework/base_database.h rename to c/driver/framework/database.h index 76901980c7..ff0dd756cb 100644 --- 
a/c/driver/framework/base_database.h +++ b/c/driver/framework/database.h @@ -20,7 +20,7 @@ #include #include -#include +#include #include "driver/framework/base_driver.h" #include "driver/framework/status.h" @@ -34,42 +34,22 @@ namespace adbc::driver { /// define a constexpr static symbol called kErrorPrefix that is used to /// construct error messages. template -class DatabaseBase : public ObjectBase { +class Database : public BaseDatabase { public: - using Base = DatabaseBase; + using Base = Database; - DatabaseBase() : ObjectBase() {} - ~DatabaseBase() = default; - - /// \internal - AdbcStatusCode Init(void* parent, AdbcError* error) override { - if (auto status = impl().InitImpl(); !status.ok()) { - return status.ToAdbc(error); - } - return ObjectBase::Init(parent, error); - } - - /// \internal - AdbcStatusCode Release(AdbcError* error) override { - return impl().ReleaseImpl().ToAdbc(error); - } - - /// \internal - AdbcStatusCode SetOption(std::string_view key, Option value, - AdbcError* error) override { - return impl().SetOptionImpl(key, std::move(value)).ToAdbc(error); - } + Database() : BaseDatabase() {} + ~Database() = default; /// \brief Initialize the database. - virtual Status InitImpl() { return status::Ok(); } + virtual Status InitImpl() { return BaseDatabase::InitImpl(); } /// \brief Release the database. - virtual Status ReleaseImpl() { return status::Ok(); } + virtual Status ReleaseImpl() { return BaseDatabase::ReleaseImpl(); } /// \brief Set an option. May be called prior to InitImpl. virtual Status SetOptionImpl(std::string_view key, Option value) { - return status::NotImplemented("{} Unknown database option {}={}", - Derived::kErrorPrefix, key, value); + return BaseDatabase::SetOptionImpl(key, value); } private: diff --git a/c/driver/framework/meson.build b/c/driver/framework/meson.build new file mode 100644 index 0000000000..08be53eacb --- /dev/null +++ b/c/driver/framework/meson.build @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +adbc_framework_lib = library( + 'adbc_driver_framework', + sources: [ + 'objects.cc', + 'utility.cc', + ], + include_directories: [include_dir, c_dir], + link_with: [adbc_common_lib], + dependencies: [nanoarrow_dep, fmt_dep], + install: true, +) diff --git a/c/driver/framework/objects.cc b/c/driver/framework/objects.cc index 67f5b26f01..691f6e4145 100644 --- a/c/driver/framework/objects.cc +++ b/c/driver/framework/objects.cc @@ -19,11 +19,174 @@ #include -#include "driver/framework/catalog.h" +#include "nanoarrow/nanoarrow.hpp" + #include "driver/framework/status.h" +#include "driver/framework/utility.h" namespace adbc::driver { +Status MakeGetObjectsSchema(struct ArrowSchema* schema) { + ArrowSchemaInit(schema); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "catalog_name")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[1], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[1], "catalog_db_schemas")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema->children[1]->children[0], 2)); + + struct ArrowSchema* db_schema_schema = schema->children[1]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_tables")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 4)); + + struct ArrowSchema* table_schema = db_schema_schema->children[1]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[0], "table_name")); + table_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[1], "table_type")); + table_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[2], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[2], "table_columns")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(table_schema->children[2]->children[0], 19)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[3], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(table_schema->children[3], "table_constraints")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(table_schema->children[3]->children[0], 4)); + + struct ArrowSchema* column_schema = table_schema->children[2]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[0], "column_name")); + column_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[1], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[1], "ordinal_position")); + UNWRAP_ERRNO(Internal, + 
ArrowSchemaSetType(column_schema->children[2], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[2], "remarks")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[3], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[3], "xdbc_data_type")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[4], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[4], "xdbc_type_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[5], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[5], "xdbc_column_size")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[6], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[6], "xdbc_decimal_digits")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[7], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[7], "xdbc_num_prec_radix")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[8], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[8], "xdbc_nullable")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[9], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[9], "xdbc_column_def")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[10], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[10], "xdbc_sql_data_type")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[11], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[11], "xdbc_datetime_sub")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[12], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[12], "xdbc_char_octet_length")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[13], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[13], "xdbc_is_nullable")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[14], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[14], "xdbc_scope_catalog")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[15], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[15], "xdbc_scope_schema")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[16], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[16], "xdbc_scope_table")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[17], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[17], "xdbc_is_autoincrement")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[18], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[18], + "xdbc_is_generatedcolumn")); + + struct ArrowSchema* constraint_schema = table_schema->children[3]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + 
ArrowSchemaSetName(constraint_schema->children[0], "constraint_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(constraint_schema->children[1], "constraint_type")); + constraint_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[2], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[2], + "constraint_column_names")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(constraint_schema->children[2]->children[0], + NANOARROW_TYPE_STRING)); + constraint_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[3], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[3], + "constraint_column_usage")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(constraint_schema->children[3]->children[0], 4)); + + struct ArrowSchema* usage_schema = constraint_schema->children[3]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[0], "fk_catalog")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[1], "fk_db_schema")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[2], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[2], "fk_table")); + usage_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[3], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[3], "fk_column_name")); + usage_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; + + return status::Ok(); +} + namespace { /// \brief A helper to convert std::string_view to Nanoarrow's ArrowStringView. 
ArrowStringView ToStringView(std::string_view s) { @@ -113,7 +276,7 @@ struct GetObjectsBuilder { private: Status InitArrowArray() { - UNWRAP_STATUS(AdbcInitConnectionObjectsSchema(schema)); + UNWRAP_STATUS(MakeGetObjectsSchema(schema)); UNWRAP_NANOARROW(na_error, Internal, ArrowArrayInitFromSchema(array, schema, &na_error)); UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array)); @@ -121,7 +284,7 @@ struct GetObjectsBuilder { } Status AppendCatalogs() { - UNWRAP_STATUS(helper->LoadCatalogs()); + UNWRAP_STATUS(helper->LoadCatalogs(catalog_filter)); while (true) { UNWRAP_RESULT(auto maybe_catalog, helper->NextCatalog()); if (!maybe_catalog.has_value()) break; @@ -139,7 +302,7 @@ struct GetObjectsBuilder { } Status AppendSchemas(std::string_view catalog) { - UNWRAP_STATUS(helper->LoadSchemas(catalog)); + UNWRAP_STATUS(helper->LoadSchemas(catalog, schema_filter)); while (true) { UNWRAP_RESULT(auto maybe_schema, helper->NextSchema()); if (!maybe_schema.has_value()) break; @@ -160,7 +323,7 @@ struct GetObjectsBuilder { } Status AppendTables(std::string_view catalog, std::string_view schema) { - UNWRAP_STATUS(helper->LoadTables(catalog, schema)); + UNWRAP_STATUS(helper->LoadTables(catalog, schema, table_filter, table_types)); while (true) { UNWRAP_RESULT(auto maybe_table, helper->NextTable()); if (!maybe_table.has_value()) break; @@ -185,7 +348,7 @@ struct GetObjectsBuilder { Status AppendColumns(std::string_view catalog, std::string_view schema, std::string_view table) { - UNWRAP_STATUS(helper->LoadColumns(catalog, schema, table)); + UNWRAP_STATUS(helper->LoadColumns(catalog, schema, table, column_filter)); while (true) { UNWRAP_RESULT(auto maybe_column, helper->NextColumn()); if (!maybe_column.has_value()) break; @@ -319,34 +482,34 @@ struct GetObjectsBuilder { std::optional table_filter; std::optional column_filter; const std::vector& table_types; - struct ArrowSchema* schema; - struct ArrowArray* array; + struct ArrowSchema* schema = nullptr; + struct ArrowArray* array = nullptr; struct ArrowError na_error; - struct ArrowArray* catalog_name_col; - struct ArrowArray* catalog_db_schemas_col; - struct ArrowArray* catalog_db_schemas_items; - struct ArrowArray* db_schema_name_col; - struct ArrowArray* db_schema_tables_col; - struct ArrowArray* schema_table_items; - struct ArrowArray* table_name_col; - struct ArrowArray* table_type_col; - struct ArrowArray* table_columns_col; - struct ArrowArray* table_columns_items; - struct ArrowArray* column_name_col; - struct ArrowArray* column_position_col; - struct ArrowArray* column_remarks_col; - struct ArrowArray* table_constraints_col; - struct ArrowArray* table_constraints_items; - struct ArrowArray* constraint_name_col; - struct ArrowArray* constraint_type_col; - struct ArrowArray* constraint_column_names_col; - struct ArrowArray* constraint_column_name_col; - struct ArrowArray* constraint_column_usages_col; - struct ArrowArray* constraint_column_usage_items; - struct ArrowArray* fk_catalog_col; - struct ArrowArray* fk_db_schema_col; - struct ArrowArray* fk_table_col; - struct ArrowArray* fk_column_name_col; + struct ArrowArray* catalog_name_col = nullptr; + struct ArrowArray* catalog_db_schemas_col = nullptr; + struct ArrowArray* catalog_db_schemas_items = nullptr; + struct ArrowArray* db_schema_name_col = nullptr; + struct ArrowArray* db_schema_tables_col = nullptr; + struct ArrowArray* schema_table_items = nullptr; + struct ArrowArray* table_name_col = nullptr; + struct ArrowArray* table_type_col = nullptr; + struct ArrowArray* table_columns_col = 
nullptr; + struct ArrowArray* table_columns_items = nullptr; + struct ArrowArray* column_name_col = nullptr; + struct ArrowArray* column_position_col = nullptr; + struct ArrowArray* column_remarks_col = nullptr; + struct ArrowArray* table_constraints_col = nullptr; + struct ArrowArray* table_constraints_items = nullptr; + struct ArrowArray* constraint_name_col = nullptr; + struct ArrowArray* constraint_type_col = nullptr; + struct ArrowArray* constraint_column_names_col = nullptr; + struct ArrowArray* constraint_column_name_col = nullptr; + struct ArrowArray* constraint_column_usages_col = nullptr; + struct ArrowArray* constraint_column_usage_items = nullptr; + struct ArrowArray* fk_catalog_col = nullptr; + struct ArrowArray* fk_db_schema_col = nullptr; + struct ArrowArray* fk_table_col = nullptr; + struct ArrowArray* fk_column_name_col = nullptr; }; } // namespace @@ -356,9 +519,14 @@ Status BuildGetObjects(GetObjectsHelper* helper, GetObjectsDepth depth, std::optional table_filter, std::optional column_filter, const std::vector& table_types, - struct ArrowSchema* schema, struct ArrowArray* array) { - return GetObjectsBuilder(helper, depth, catalog_filter, schema_filter, table_filter, - column_filter, table_types, schema, array) - .Build(); + struct ArrowArrayStream* out) { + nanoarrow::UniqueSchema schema; + nanoarrow::UniqueArray array; + UNWRAP_STATUS(GetObjectsBuilder(helper, depth, catalog_filter, schema_filter, + table_filter, column_filter, table_types, schema.get(), + array.get()) + .Build()); + MakeArrayStream(schema.get(), array.get(), out); + return status::Ok(); } } // namespace adbc::driver diff --git a/c/driver/framework/objects.h b/c/driver/framework/objects.h index 0d14c5c6b8..3e74e78835 100644 --- a/c/driver/framework/objects.h +++ b/c/driver/framework/objects.h @@ -21,20 +21,132 @@ #include #include -#include - -#include "driver/framework/catalog.h" +#include #include "driver/framework/status.h" #include "driver/framework/type_fwd.h" namespace adbc::driver { + +/// \defgroup adbc-framework-catalog Catalog Utilities +/// Utilities for implementing catalog/metadata-related functions. +/// +/// @{ + +/// \brief Create the ArrowSchema for AdbcConnectionGetObjects(). +Status MakeGetObjectsSchema(ArrowSchema* schema); + +/// \brief The GetObjects level. +enum class GetObjectsDepth { + kCatalogs, + kSchemas, + kTables, + kColumns, +}; + +/// \brief Helper to implement GetObjects. +/// +/// Drivers can implement methods of the GetObjectsHelper in a driver-specific +/// class to get a compliant implementation of AdbcConnectionGetObjects(). 
+struct GetObjectsHelper { + virtual ~GetObjectsHelper() = default; + + struct Table { + std::string_view name; + std::string_view type; + }; + + struct ColumnXdbc { + std::optional xdbc_data_type; + std::optional xdbc_type_name; + std::optional xdbc_column_size; + std::optional xdbc_decimal_digits; + std::optional xdbc_num_prec_radix; + std::optional xdbc_nullable; + std::optional xdbc_column_def; + std::optional xdbc_sql_data_type; + std::optional xdbc_datetime_sub; + std::optional xdbc_char_octet_length; + std::optional xdbc_is_nullable; + std::optional xdbc_scope_catalog; + std::optional xdbc_scope_schema; + std::optional xdbc_scope_table; + std::optional xdbc_is_autoincrement; + std::optional xdbc_is_generatedcolumn; + }; + + struct Column { + std::string_view column_name; + int32_t ordinal_position; + std::optional remarks; + std::optional xdbc; + }; + + struct ConstraintUsage { + std::optional catalog; + std::optional schema; + std::string_view table; + std::string_view column; + }; + + struct Constraint { + std::optional name; + std::string_view type; + std::vector column_names; + std::optional> usage; + }; + + Status Close() { return status::Ok(); } + + /// \brief Fetch all metadata needed. The driver is free to delay loading + /// but this gives it a chance to load data up front. + virtual Status Load(GetObjectsDepth depth, + std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + const std::vector& table_types) { + return status::NotImplemented("GetObjects"); + } + + virtual Status LoadCatalogs(std::optional catalog_filter) { + return status::NotImplemented("GetObjects at depth = catalog"); + }; + + virtual Result> NextCatalog() { return std::nullopt; } + + virtual Status LoadSchemas(std::string_view catalog, + std::optional schema_filter) { + return status::NotImplemented("GetObjects at depth = schema"); + }; + + virtual Result> NextSchema() { return std::nullopt; } + + virtual Status LoadTables(std::string_view catalog, std::string_view schema, + std::optional table_filter, + const std::vector& table_types) { + return status::NotImplemented("GetObjects at depth = table"); + }; + + virtual Result> NextTable() { return std::nullopt; } + + virtual Status LoadColumns(std::string_view catalog, std::string_view schema, + std::string_view table, + std::optional column_filter) { + return status::NotImplemented("GetObjects at depth = column"); + }; + + virtual Result> NextColumn() { return std::nullopt; } + + virtual Result> NextConstraint() { return std::nullopt; } +}; + /// \brief A helper that implements GetObjects. -/// The schema/array/helper lifetime are caller-managed. +/// The out/helper lifetime are caller-managed. 
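// ---------------------------------------------------------------------------
// Editorial note (not part of this patch): a minimal sketch of how a driver
// might plug into the GetObjectsHelper/BuildGetObjects pair defined above.
// The class name ExampleGetObjectsHelper and the single hard-coded catalog are
// hypothetical; a real driver would run metadata queries in its Load*/Next*
// overrides and support deeper levels than kCatalogs.
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "driver/framework/objects.h"

namespace {
struct ExampleGetObjectsHelper : adbc::driver::GetObjectsHelper {
  std::vector<std::string> catalogs = {"main"};
  size_t next = 0;

  adbc::driver::Status LoadCatalogs(
      std::optional<std::string_view> catalog_filter) override {
    next = 0;  // a real driver would push catalog_filter into its metadata query
    return adbc::driver::status::Ok();
  }

  adbc::driver::Result<std::optional<std::string_view>> NextCatalog() override {
    if (next >= catalogs.size()) return std::nullopt;
    return std::optional<std::string_view>(catalogs[next++]);
  }
};

// A connection's GetObjects() path could then build the result stream with:
//   ExampleGetObjectsHelper helper;
//   ArrowArrayStream stream;
//   UNWRAP_STATUS(BuildGetObjects(&helper, adbc::driver::GetObjectsDepth::kCatalogs,
//                                 std::nullopt, std::nullopt, std::nullopt,
//                                 std::nullopt, /*table_types=*/{}, &stream));
}  // namespace
// ---------------------------------------------------------------------------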
Status BuildGetObjects(GetObjectsHelper* helper, GetObjectsDepth depth, std::optional catalog_filter, std::optional schema_filter, std::optional table_filter, std::optional column_filter, const std::vector& table_types, - struct ArrowSchema* schema, struct ArrowArray* array); + ArrowArrayStream* out); } // namespace adbc::driver diff --git a/c/driver/framework/base_statement.h b/c/driver/framework/statement.h similarity index 73% rename from c/driver/framework/base_statement.h rename to c/driver/framework/statement.h index 362ce2291f..c07324849c 100644 --- a/c/driver/framework/base_statement.h +++ b/c/driver/framework/statement.h @@ -25,48 +25,17 @@ #include #include -#include "driver/common/options.h" #include "driver/framework/base_driver.h" #include "driver/framework/status.h" +#include "driver/framework/utility.h" namespace adbc::driver { -/// One-value ArrowArrayStream used to unify the implementations of Bind -struct OneValueStream { - struct ArrowSchema schema; - struct ArrowArray array; - - static int GetSchema(struct ArrowArrayStream* self, struct ArrowSchema* out) { - OneValueStream* stream = static_cast(self->private_data); - return ArrowSchemaDeepCopy(&stream->schema, out); - } - static int GetNext(struct ArrowArrayStream* self, struct ArrowArray* out) { - OneValueStream* stream = static_cast(self->private_data); - *out = stream->array; - stream->array.release = nullptr; - return 0; - } - static const char* GetLastError(struct ArrowArrayStream* self) { return NULL; } - static void Release(struct ArrowArrayStream* self) { - OneValueStream* stream = static_cast(self->private_data); - if (stream->schema.release) { - stream->schema.release(&stream->schema); - stream->schema.release = nullptr; - } - if (stream->array.release) { - stream->array.release(&stream->array); - stream->array.release = nullptr; - } - delete stream; - self->release = nullptr; - } -}; - /// \brief A base implementation of a statement. template -class StatementBase : public ObjectBase { +class Statement : public BaseStatement { public: - using Base = StatementBase; + using Base = Statement; /// \brief What to do in ingestion when the table does not exist. enum class TableDoesNotExist { @@ -103,37 +72,30 @@ class StatementBase : public ObjectBase { /// \brief Statement state: one of the above. 
using State = std::variant; - StatementBase() : ObjectBase() { + Statement() : BaseStatement() { std::memset(&bind_parameters_, 0, sizeof(bind_parameters_)); } - ~StatementBase() = default; + ~Statement() = default; AdbcStatusCode Bind(ArrowArray* values, ArrowSchema* schema, AdbcError* error) { if (!values || !values->release) { - return status::InvalidArgument("{} Bind: must provide non-NULL array", - Derived::kErrorPrefix) + return status::InvalidArgument(Derived::kErrorPrefix, + " Bind: must provide non-NULL array") .ToAdbc(error); } else if (!schema || !schema->release) { - return status::InvalidArgument("{} Bind: must provide non-NULL stream", - Derived::kErrorPrefix) + return status::InvalidArgument(Derived::kErrorPrefix, + " Bind: must provide non-NULL stream") .ToAdbc(error); } if (bind_parameters_.release) bind_parameters_.release(&bind_parameters_); - // Make a one-value stream - bind_parameters_.private_data = new OneValueStream{*schema, *values}; - bind_parameters_.get_schema = &OneValueStream::GetSchema; - bind_parameters_.get_next = &OneValueStream::GetNext; - bind_parameters_.get_last_error = &OneValueStream::GetLastError; - bind_parameters_.release = &OneValueStream::Release; - std::memset(values, 0, sizeof(*values)); - std::memset(schema, 0, sizeof(*schema)); + MakeArrayStream(schema, values, &bind_parameters_); return ADBC_STATUS_OK; } AdbcStatusCode BindStream(ArrowArrayStream* stream, AdbcError* error) { if (!stream || !stream->release) { - return status::InvalidArgument("{} BindStream: must provide non-NULL stream", - Derived::kErrorPrefix) + return status::InvalidArgument(Derived::kErrorPrefix, + " BindStream: must provide non-NULL stream") .ToAdbc(error); } if (bind_parameters_.release) bind_parameters_.release(&bind_parameters_); @@ -157,14 +119,13 @@ class StatementBase : public ObjectBase { [&](auto&& state) -> AdbcStatusCode { using T = std::decay_t; if constexpr (std::is_same_v) { - return status::InvalidState( - "{} Cannot ExecuteQuery without setting the query", - Derived::kErrorPrefix) + return status::InvalidState(Derived::kErrorPrefix, + " Cannot ExecuteQuery without setting the query") .ToAdbc(error); } else if constexpr (std::is_same_v) { if (stream) { - return status::InvalidState("{} Cannot ingest with result set", - Derived::kErrorPrefix) + return status::InvalidState(Derived::kErrorPrefix, + " Cannot ingest with result set") .ToAdbc(error); } RAISE_RESULT(error, int64_t rows, impl().ExecuteIngestImpl(state)); @@ -201,19 +162,19 @@ class StatementBase : public ObjectBase { using T = std::decay_t; if constexpr (std::is_same_v) { return status::InvalidState( - "{} Cannot GetParameterSchema without setting the query", - Derived::kErrorPrefix) + Derived::kErrorPrefix, + " Cannot GetParameterSchema without setting the query") .ToAdbc(error); } else if constexpr (std::is_same_v) { - return status::InvalidState("{} Cannot GetParameterSchema in bulk ingestion", - Derived::kErrorPrefix) + return status::InvalidState(Derived::kErrorPrefix, + " Cannot GetParameterSchema in bulk ingestion") .ToAdbc(error); } else if constexpr (std::is_same_v) { return impl().GetParameterSchemaImpl(state, schema).ToAdbc(error); } else if constexpr (std::is_same_v) { return status::InvalidState( - "{} Cannot GetParameterSchema without calling Prepare", - Derived::kErrorPrefix) + Derived::kErrorPrefix, + " Cannot GetParameterSchema without calling Prepare") .ToAdbc(error); } else { static_assert(!sizeof(T), "case not implemented"); @@ -223,7 +184,7 @@ class StatementBase : public 
ObjectBase { } AdbcStatusCode Init(void* parent, AdbcError* error) { - lifecycle_state_ = LifecycleState::kInitialized; + this->lifecycle_state_ = LifecycleState::kInitialized; if (auto status = impl().InitImpl(parent); !status.ok()) { return status.ToAdbc(error); } @@ -236,12 +197,12 @@ class StatementBase : public ObjectBase { using T = std::decay_t; if constexpr (std::is_same_v) { return status::InvalidState( - "{} Cannot Prepare without setting the query", - Derived::kErrorPrefix); + Derived::kErrorPrefix, + " Cannot Prepare without setting the query"); } else if constexpr (std::is_same_v) { return status::InvalidState( - "{} Cannot Prepare without setting the query", - Derived::kErrorPrefix); + Derived::kErrorPrefix, + " Cannot Prepare without setting the query"); } else if constexpr (std::is_same_v) { // No-op return status::Ok(); @@ -291,8 +252,8 @@ class StatementBase : public ObjectBase { state.table_does_not_exist_ = TableDoesNotExist::kCreate; state.table_exists_ = TableExists::kReplace; } else { - return status::InvalidArgument("{} Invalid ingest mode '{}'", - Derived::kErrorPrefix, key, value) + return status::InvalidArgument(Derived::kErrorPrefix, " Invalid ingest mode '", + key, "': ", value.Format()) .ToAdbc(error); } return ADBC_STATUS_OK; @@ -360,46 +321,46 @@ class StatementBase : public ObjectBase { } Result ExecuteIngestImpl(IngestState& state) { - return status::NotImplemented("{} Bulk ingest is not implemented", - Derived::kErrorPrefix); + return status::NotImplemented(Derived::kErrorPrefix, + " Bulk ingest is not implemented"); } Result ExecuteQueryImpl(PreparedState& state, ArrowArrayStream* stream) { - return status::NotImplemented("{} ExecuteQuery is not implemented", - Derived::kErrorPrefix); + return status::NotImplemented(Derived::kErrorPrefix, + " ExecuteQuery is not implemented"); } Result ExecuteQueryImpl(QueryState& state, ArrowArrayStream* stream) { - return status::NotImplemented("{} ExecuteQuery is not implemented", - Derived::kErrorPrefix); + return status::NotImplemented(Derived::kErrorPrefix, + " ExecuteQuery is not implemented"); } Result ExecuteUpdateImpl(PreparedState& state) { - return status::NotImplemented("{} ExecuteQuery (update) is not implemented", - Derived::kErrorPrefix); + return status::NotImplemented(Derived::kErrorPrefix, + " ExecuteQuery (update) is not implemented"); } Result ExecuteUpdateImpl(QueryState& state) { - return status::NotImplemented("{} ExecuteQuery (update) is not implemented", - Derived::kErrorPrefix); + return status::NotImplemented(Derived::kErrorPrefix, + " ExecuteQuery (update) is not implemented"); } Status GetParameterSchemaImpl(PreparedState& state, ArrowSchema* schema) { - return status::NotImplemented("{} GetParameterSchema is not implemented", - Derived::kErrorPrefix); + return status::NotImplemented(Derived::kErrorPrefix, + " GetParameterSchema is not implemented"); } Status InitImpl(void* parent) { return status::Ok(); } Status PrepareImpl(QueryState& state) { - return status::NotImplemented("{} Prepare is not implemented", Derived::kErrorPrefix); + return status::NotImplemented(Derived::kErrorPrefix, " Prepare is not implemented"); } Status ReleaseImpl() { return status::Ok(); } Status SetOptionImpl(std::string_view key, Option value) { - return status::NotImplemented("{} Unknown statement option {}={}", - Derived::kErrorPrefix, key, value); + return status::NotImplemented(Derived::kErrorPrefix, " Unknown statement option ", + key, "=", value.Format()); } protected: diff --git 
a/c/driver/framework/status.h b/c/driver/framework/status.h index b6375b8e87..22e484dd3c 100644 --- a/c/driver/framework/status.h +++ b/c/driver/framework/status.h @@ -20,13 +20,17 @@ #include #include #include +#include #include #include #include #include -#include +#if defined(ADBC_FRAMEWORK_USE_FMT) #include +#endif + +#include /// \file status.h @@ -65,6 +69,19 @@ class Status { impl_->details.push_back({std::move(key), std::move(value)}); } + /// \brief Set the sqlstate of this status + void SetSqlState(std::string sqlstate) { + assert(impl_ != nullptr); + std::memset(impl_->sql_state, 0, sizeof(impl_->sql_state)); + for (size_t i = 0; i < sqlstate.size(); i++) { + if (i >= sizeof(impl_->sql_state)) { + break; + } + + impl_->sql_state[i] = sqlstate[i]; + } + } + /// \brief Export this status to an AdbcError. AdbcStatusCode ToAdbc(AdbcError* adbc_error) const { if (impl_ == nullptr) return ADBC_STATUS_OK; @@ -108,7 +125,29 @@ class Status { return status; } + // Helpers to create statuses with known codes + static Status Ok() { return Status(); } + +#define STATUS_CTOR(NAME, CODE) \ + template \ + static Status NAME(Args&&... args) { \ + std::stringstream ss; \ + ([&] { ss << args; }(), ...); \ + return Status(ADBC_STATUS_##CODE, ss.str()); \ + } + + STATUS_CTOR(Internal, INTERNAL) + STATUS_CTOR(InvalidArgument, INVALID_ARGUMENT) + STATUS_CTOR(InvalidState, INVALID_STATE) + STATUS_CTOR(IO, IO) + STATUS_CTOR(NotFound, NOT_FOUND) + STATUS_CTOR(NotImplemented, NOT_IMPLEMENTED) + STATUS_CTOR(Unknown, UNKNOWN) + +#undef STATUS_CTOR + private: + /// \brief Private Status implementation details struct Impl { // invariant: code is never OK AdbcStatusCode code; @@ -129,6 +168,8 @@ class Status { template friend class Driver; + // Allow access to these for drivers transitioning to the framework + public: int CDetailCount() const { return impl_ ? static_cast(impl_->details.size()) : 0; } AdbcErrorDetail CDetail(int index) const { @@ -140,6 +181,7 @@ class Status { detail.second.size()}; } + private: static void CRelease(AdbcError* error) { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { auto* error_obj = reinterpret_cast(error->private_data); @@ -248,11 +290,14 @@ class Result { namespace adbc::driver::status { -#define STATUS_CTOR(NAME, CODE) \ - template \ - static Status NAME(std::string_view format_string, Args&&... args) { \ - auto message = fmt::vformat(format_string, fmt::make_format_args(args...)); \ - return Status(ADBC_STATUS_##CODE, std::move(message)); \ +inline driver::Status Ok() { return driver::Status(); } + +#define STATUS_CTOR(NAME, CODE) \ + template \ + static Status NAME(Args&&... args) { \ + std::stringstream ss; \ + ([&] { ss << args; }(), ...); \ + return Status(ADBC_STATUS_##CODE, ss.str()); \ } // TODO: unit tests for internal utilities @@ -266,27 +311,49 @@ STATUS_CTOR(Unknown, UNKNOWN) #undef STATUS_CTOR -inline driver::Status Ok() { return driver::Status(); } +} // namespace adbc::driver::status -#define UNWRAP_ERRNO_IMPL(NAME, CODE, RHS) \ - auto&& NAME = (RHS); \ - if (NAME != 0) { \ - return adbc::driver::status::CODE("Nanoarrow call failed: {} = ({}) {}", #RHS, NAME, \ - std::strerror(NAME)); \ +#if defined(ADBC_FRAMEWORK_USE_FMT) +namespace adbc::driver::status::fmt { + +#define STATUS_CTOR(NAME, CODE) \ + template \ + static Status NAME(std::string_view format_string, Args&&... 
args) { \ + auto message = ::fmt::vformat(format_string, ::fmt::make_format_args(args...)); \ + return Status(ADBC_STATUS_##CODE, std::move(message)); \ } -#define UNWRAP_ERRNO(CODE, RHS) \ - UNWRAP_ERRNO_IMPL(UNWRAP_RESULT_NAME(driver_errno, __COUNTER__), CODE, RHS) +// TODO: unit tests for internal utilities +STATUS_CTOR(Internal, INTERNAL) +STATUS_CTOR(InvalidArgument, INVALID_ARGUMENT) +STATUS_CTOR(InvalidState, INVALID_STATE) +STATUS_CTOR(IO, IO) +STATUS_CTOR(NotFound, NOT_FOUND) +STATUS_CTOR(NotImplemented, NOT_IMPLEMENTED) +STATUS_CTOR(Unknown, UNKNOWN) + +#undef STATUS_CTOR + +} // namespace adbc::driver::status::fmt +#endif -#define UNWRAP_NANOARROW_IMPL(NAME, ERROR, CODE, RHS) \ +#define UNWRAP_ERRNO_IMPL(NAME, CODE, RHS) \ auto&& NAME = (RHS); \ if (NAME != 0) { \ - return adbc::driver::status::CODE("Nanoarrow call failed: {} = ({}) {}. {}", #RHS, \ - NAME, std::strerror(NAME), (ERROR).message); \ + return adbc::driver::status::CODE("Call failed: ", #RHS, " = (errno ", NAME, ") ", \ + std::strerror(NAME)); \ + } + +#define UNWRAP_ERRNO(CODE, RHS) \ + UNWRAP_ERRNO_IMPL(UNWRAP_RESULT_NAME(driver_errno, __COUNTER__), CODE, RHS) + +#define UNWRAP_NANOARROW_IMPL(NAME, ERROR, CODE, RHS) \ + auto&& NAME = (RHS); \ + if (NAME != 0) { \ + return adbc::driver::status::CODE("nanoarrow call failed: ", #RHS, " = (", NAME, \ + ") ", std::strerror(NAME), ". ", (ERROR).message); \ } #define UNWRAP_NANOARROW(ERROR, CODE, RHS) \ UNWRAP_NANOARROW_IMPL(UNWRAP_RESULT_NAME(driver_errno_na, __COUNTER__), ERROR, CODE, \ RHS) - -} // namespace adbc::driver::status diff --git a/c/driver/framework/utility.cc b/c/driver/framework/utility.cc new file mode 100644 index 0000000000..d281776e59 --- /dev/null +++ b/c/driver/framework/utility.cc @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
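// ---------------------------------------------------------------------------
// Editorial note (not part of this patch): the STATUS_CTOR rewrite above trades
// fmt-style format strings for ostream-style concatenation, so call sites now
// pass message pieces as separate arguments; the fmt-based constructors remain
// available only under ADBC_FRAMEWORK_USE_FMT in adbc::driver::status::fmt. A
// small sketch under those assumptions (the function and its arguments are
// hypothetical):
#include <string_view>

#include "driver/framework/status.h"

namespace {
adbc::driver::Status ExampleTableNotFound(std::string_view error_prefix,
                                          std::string_view table_name) {
  // Before: status::NotFound("{} table {} not found", error_prefix, table_name)
  // After: arguments are streamed together in order, no format string needed.
  return adbc::driver::status::NotFound(error_prefix, " table ", table_name,
                                        " not found");
}
}  // namespace
// ---------------------------------------------------------------------------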
+ +#include "driver/framework/utility.h" + +#include +#include + +#include "arrow-adbc/adbc.h" +#include "nanoarrow/nanoarrow.hpp" + +namespace adbc::driver { + +void MakeEmptyStream(ArrowSchema* schema, ArrowArrayStream* out) { + nanoarrow::EmptyArrayStream(schema).ToArrayStream(out); +} + +void MakeArrayStream(ArrowSchema* schema, ArrowArray* array, ArrowArrayStream* out) { + if (array->length == 0) { + ArrowArrayRelease(array); + std::memset(array, 0, sizeof(ArrowArray)); + + MakeEmptyStream(schema, out); + } else { + nanoarrow::VectorArrayStream(schema, array).ToArrayStream(out); + } +} + +Status MakeTableTypesStream(const std::vector& table_types, + ArrowArrayStream* out) { + nanoarrow::UniqueArray array; + nanoarrow::UniqueSchema schema; + ArrowSchemaInit(schema.get()); + + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema.get(), NANOARROW_TYPE_STRUCT)); + UNWRAP_ERRNO(Internal, ArrowSchemaAllocateChildren(schema.get(), /*num_columns=*/1)); + ArrowSchemaInit(schema.get()->children[0]); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(schema.get()->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema.get()->children[0], "table_type")); + schema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + UNWRAP_ERRNO(Internal, ArrowArrayInitFromSchema(array.get(), schema.get(), NULL)); + UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array.get())); + + for (std::string const& table_type : table_types) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(array->children[0], + ArrowCharView(table_type.c_str()))); + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(array.get())); + } + + UNWRAP_ERRNO(Internal, ArrowArrayFinishBuildingDefault(array.get(), nullptr)); + MakeArrayStream(schema.get(), array.get(), out); + return status::Ok(); +} + +namespace { +Status MakeGetInfoInit(ArrowSchema* schema, ArrowArray* array) { + ArrowSchemaInit(schema); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); + + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_UINT32)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "info_name")); + schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + ArrowSchema* info_value = schema->children[1]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeUnion(info_value, NANOARROW_TYPE_DENSE_UNION, 6)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value, "info_value")); + + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[0], "string_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[1], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[1], "bool_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[2], NANOARROW_TYPE_INT64)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[2], "int64_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[3], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[3], "int32_bitmask")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[4], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[4], "string_list")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[5], NANOARROW_TYPE_MAP)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(info_value->children[5], "int32_to_int32_list_map")); + + UNWRAP_ERRNO(Internal, 
ArrowSchemaSetType(info_value->children[4]->children[0], + NANOARROW_TYPE_STRING)); + + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[0], + NANOARROW_TYPE_INT32)); + info_value->children[5]->children[0]->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[1], + NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO( + Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[1]->children[0], + NANOARROW_TYPE_INT32)); + + UNWRAP_ERRNO(Internal, ArrowArrayInitFromSchema(array, schema, nullptr)); + UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array)); + + return status::Ok(); +} + +Status MakeGetInfoAppendString(ArrowArray* array, uint32_t info_code, + std::string_view info_value) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); + // Append to type variant + ArrowStringView value; + value.data = info_value.data(); + value.size_bytes = static_cast(info_value.size()); + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(array->children[1]->children[0], value)); + // Append type code/offset + UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/0)); + return status::Ok(); +} + +Status MakeGetInfoAppendInt(ArrowArray* array, uint32_t info_code, int64_t info_value) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); + // Append to type variant + UNWRAP_ERRNO(Internal, + ArrowArrayAppendInt(array->children[1]->children[2], info_value)); + // Append type code/offset + UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2)); + return status::Ok(); +} +} // namespace + +Status MakeGetInfoStream(const std::vector& infos, ArrowArrayStream* out) { + nanoarrow::UniqueSchema schema; + nanoarrow::UniqueArray array; + + UNWRAP_STATUS(MakeGetInfoInit(schema.get(), array.get())); + + for (const auto& info : infos) { + UNWRAP_STATUS(std::visit( + [&](auto&& value) -> Status { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return MakeGetInfoAppendString(array.get(), info.code, value); + } else if constexpr (std::is_same_v) { + return MakeGetInfoAppendInt(array.get(), info.code, value); + } else { + static_assert(!sizeof(T), "info value type not implemented"); + } + }, + info.value)); + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(array.get())); + } + + ArrowError na_error = {0}; + UNWRAP_NANOARROW(na_error, Internal, + ArrowArrayFinishBuildingDefault(array.get(), &na_error)); + MakeArrayStream(schema.get(), array.get(), out); + return status::Ok(); +} + +} // namespace adbc::driver diff --git a/c/driver/framework/utility.h b/c/driver/framework/utility.h new file mode 100644 index 0000000000..af60594ea4 --- /dev/null +++ b/c/driver/framework/utility.h @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include + +#include "driver/framework/status.h" + +namespace adbc::driver { + +/// \brief Create an ArrowArrayStream with zero batches from a given ArrowSchema. +/// \ingroup adbc-framework-catalog +/// +/// This function takes ownership of schema; the caller is responsible for +/// releasing out. +void MakeEmptyStream(ArrowSchema* schema, ArrowArrayStream* out); + +/// \brief Create an ArrowArrayStream from a given ArrowSchema and ArrowArray. +/// \ingroup adbc-framework-catalog +/// +/// The resulting ArrowArrayStream will contain zero batches if the length of the +/// array is zero, or exactly one batch if the length of the array is non-zero. +/// This function takes ownership of schema and array; the caller is responsible for +/// releasing out. +void MakeArrayStream(ArrowSchema* schema, ArrowArray* array, ArrowArrayStream* out); + +/// \brief Create an ArrowArrayStream representation of a vector of table types. +/// \ingroup adbc-framework-catalog +/// +/// Create an ArrowArrayStream representation of an array of table types +/// that can be used to implement AdbcConnectionGetTableTypes(). The caller is responsible +/// for releasing out on success. +Status MakeTableTypesStream(const std::vector& table_types, + ArrowArrayStream* out); + +/// \brief Representation of a single item in an array to be returned +/// from AdbcConnectionGetInfo(). +/// \ingroup adbc-framework-catalog +struct InfoValue { + uint32_t code; + std::variant value; + + InfoValue(uint32_t code, std::variant value) + : code(code), value(std::move(value)) {} + InfoValue(uint32_t code, const char* value) : InfoValue(code, std::string(value)) {} +}; + +/// \brief Create an ArrowArrayStream to be returned from AdbcConnectionGetInfo(). +/// \ingroup adbc-framework-catalog +/// +/// The caller is responsible for releasing out on success. 
+Status MakeGetInfoStream(const std::vector& infos, ArrowArrayStream* out); + +} // namespace adbc::driver diff --git a/c/driver/postgresql/CMakeLists.txt b/c/driver/postgresql/CMakeLists.txt index e8bbeac9e9..a720696c6a 100644 --- a/c/driver/postgresql/CMakeLists.txt +++ b/c/driver/postgresql/CMakeLists.txt @@ -33,6 +33,7 @@ add_arrow_lib(adbc_driver_postgresql database.cc postgresql.cc result_helper.cc + result_reader.cc statement.cc OUTPUTS ADBC_LIBRARIES @@ -57,8 +58,8 @@ add_arrow_lib(adbc_driver_postgresql foreach(LIB_TARGET ${ADBC_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_EXPORTING) target_include_directories(${LIB_TARGET} SYSTEM - PRIVATE ${REPOSITORY_ROOT} - ${REPOSITORY_ROOT}/c/ + PRIVATE ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ ${LIBPQ_INCLUDE_DIRS} ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) @@ -86,8 +87,8 @@ if(ADBC_BUILD_TESTS) ${TEST_LINK_LIBS}) target_compile_features(adbc-driver-postgresql-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-postgresql-test SYSTEM - PRIVATE ${REPOSITORY_ROOT} - ${REPOSITORY_ROOT}/c/ + PRIVATE ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ ${LIBPQ_INCLUDE_DIRS} ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) @@ -108,8 +109,8 @@ if(ADBC_BUILD_TESTS) ${TEST_LINK_LIBS}) target_compile_features(adbc-driver-postgresql-copy-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-postgresql-copy-test SYSTEM - PRIVATE ${REPOSITORY_ROOT} - ${REPOSITORY_ROOT}/c/ + PRIVATE ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ ${LIBPQ_INCLUDE_DIRS} ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) @@ -128,7 +129,7 @@ if(ADBC_BUILD_BENCHMARKS) benchmark::benchmark) # add_benchmark replaces _ with - when creating target target_include_directories(postgresql-benchmark - PRIVATE ${REPOSITORY_ROOT} ${REPOSITORY_ROOT}/c/ + PRIVATE ${REPOSITORY_ROOT}/c/ ${REPOSITORY_ROOT}/c/include/ ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) endif() diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h new file mode 100644 index 0000000000..df0b9d2ca5 --- /dev/null +++ b/c/driver/postgresql/bind_stream.h @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "copy/writer.h" +#include "error.h" +#include "postgres_type.h" +#include "postgres_util.h" +#include "result_helper.h" + +namespace adbcpq { + +/// The flag indicating to PostgreSQL that we want binary-format values. 
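// ---------------------------------------------------------------------------
// Editorial note (not part of this patch): before the PostgreSQL changes below,
// here is a sketch of how a connection might use the utility.h helpers declared
// above to answer AdbcConnectionGetInfo() and AdbcConnectionGetTableTypes().
// The literal driver name, version, and table types are placeholders.
#include <vector>

#include "arrow-adbc/adbc.h"
#include "driver/framework/utility.h"

namespace {
adbc::driver::Status ExampleGetInfo(ArrowArrayStream* out) {
  std::vector<adbc::driver::InfoValue> infos = {
      {ADBC_INFO_DRIVER_NAME, "ADBC Example Driver"},
      {ADBC_INFO_VENDOR_VERSION, "1.0.0"},
  };
  return adbc::driver::MakeGetInfoStream(infos, out);
}

adbc::driver::Status ExampleGetTableTypes(ArrowArrayStream* out) {
  return adbc::driver::MakeTableTypesStream({"table", "view"}, out);
}
}  // namespace
// ---------------------------------------------------------------------------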
+constexpr int kPgBinaryFormat = 1; + +/// Helper to manage bind parameters with a prepared statement +struct BindStream { + Handle bind; + Handle array_view; + Handle current; + Handle bind_schema; + int64_t current_row = -1; + + std::vector bind_schema_fields; + std::vector> bind_field_writers; + + // OIDs for parameter types + std::vector param_types; + std::vector param_values; + std::vector param_formats; + std::vector param_lengths; + Handle param_buffer; + + bool has_tz_field = false; + std::string tz_setting; + + struct ArrowError na_error; + + BindStream() { + this->bind->release = nullptr; + std::memset(&na_error, 0, sizeof(na_error)); + } + + void SetBind(struct ArrowArrayStream* stream) { + this->bind.reset(); + ArrowArrayStreamMove(stream, &bind.value); + } + + template + Status Begin(Callback&& callback) { + UNWRAP_NANOARROW( + na_error, Internal, + ArrowArrayStreamGetSchema(&bind.value, &bind_schema.value, &na_error)); + + struct ArrowSchemaView bind_schema_view; + UNWRAP_NANOARROW( + na_error, Internal, + ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, &na_error)); + if (bind_schema_view.type != ArrowType::NANOARROW_TYPE_STRUCT) { + return Status::InvalidState("[libpq] Bind parameters must have type STRUCT"); + } + + bind_schema_fields.resize(bind_schema->n_children); + for (size_t i = 0; i < bind_schema_fields.size(); i++) { + UNWRAP_ERRNO(Internal, + ArrowSchemaViewInit(&bind_schema_fields[i], bind_schema->children[i], + /*error*/ nullptr)); + } + + UNWRAP_NANOARROW( + na_error, Internal, + ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, &na_error)); + + ArrowBufferInit(&param_buffer.value); + + return std::move(callback)(); + } + + Status SetParamTypes(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, + const bool autocommit) { + param_types.resize(bind_schema->n_children); + param_values.resize(bind_schema->n_children); + param_lengths.resize(bind_schema->n_children); + param_formats.resize(bind_schema->n_children, kPgBinaryFormat); + bind_field_writers.resize(bind_schema->n_children); + + for (size_t i = 0; i < bind_field_writers.size(); i++) { + PostgresType type; + UNWRAP_NANOARROW(na_error, Internal, + PostgresType::FromSchema(type_resolver, bind_schema->children[i], + &type, &na_error)); + + // tz-aware timestamps require special handling to set the timezone to UTC + // prior to sending over the binary protocol; must be reset after execute + if (!has_tz_field && type.type_id() == PostgresTypeId::kTimestamptz) { + UNWRAP_STATUS(SetDatabaseTimezoneUTC(pg_conn, autocommit)); + has_tz_field = true; + } + + std::unique_ptr writer; + UNWRAP_NANOARROW( + na_error, Internal, + MakeCopyFieldWriter(bind_schema->children[i], array_view->children[i], + type_resolver, &writer, &na_error)); + + param_types[i] = type.oid(); + param_formats[i] = kPgBinaryFormat; + bind_field_writers[i] = std::move(writer); + } + + return Status::Ok(); + } + + Status SetDatabaseTimezoneUTC(PGconn* pg_conn, const bool autocommit) { + if (autocommit) { + PqResultHelper helper(pg_conn, "BEGIN"); + UNWRAP_STATUS(helper.Execute()); + } + + PqResultHelper get_tz(pg_conn, "SELECT current_setting('TIMEZONE')"); + UNWRAP_STATUS(get_tz.Execute()); + for (auto row : get_tz) { + tz_setting = row[0].value(); + } + + PqResultHelper set_utc(pg_conn, "SET TIME ZONE 'UTC'"); + UNWRAP_STATUS(set_utc.Execute()); + + return Status::Ok(); + } + + Status Prepare(PGconn* pg_conn, const std::string& query) { + PqResultHelper helper(pg_conn, query); +
UNWRAP_STATUS(helper.Prepare(param_types)); + return Status::Ok(); + } + + Status PullNextArray() { + if (current->release != nullptr) ArrowArrayRelease(¤t.value); + + UNWRAP_NANOARROW(na_error, IO, + ArrowArrayStreamGetNext(&bind.value, ¤t.value, &na_error)); + + if (current->release != nullptr) { + UNWRAP_NANOARROW( + na_error, Internal, + ArrowArrayViewSetArray(&array_view.value, ¤t.value, &na_error)); + } + + return Status::Ok(); + } + + Status EnsureNextRow() { + if (current->release != nullptr) { + current_row++; + if (current_row < current->length) { + return Status::Ok(); + } + } + + // Pull until we have an array with at least one row or the stream is finished + do { + UNWRAP_STATUS(PullNextArray()); + if (current->release == nullptr) { + current_row = -1; + return Status::Ok(); + } + } while (current->length == 0); + + current_row = 0; + return Status::Ok(); + } + + Status BindAndExecuteCurrentRow(PGconn* pg_conn, PGresult** result_out, + int result_format) { + param_buffer->size_bytes = 0; + int64_t last_offset = 0; + + for (int64_t col = 0; col < array_view->n_children; col++) { + if (!ArrowArrayViewIsNull(array_view->children[col], current_row)) { + // Note that this Write() call currently writes the (int32_t) byte size of the + // field in addition to the serialized value. + UNWRAP_NANOARROW( + na_error, Internal, + bind_field_writers[col]->Write(¶m_buffer.value, current_row, &na_error)); + } else { + UNWRAP_ERRNO(Internal, ArrowBufferAppendInt32(¶m_buffer.value, 0)); + } + + int64_t param_length = param_buffer->size_bytes - last_offset - sizeof(int32_t); + if (param_length > (std::numeric_limits::max)()) { + return Status::Internal("Paramter ", col, "serialized to >2GB of binary"); + } + + param_lengths[col] = static_cast(param_length); + last_offset = param_buffer->size_bytes; + } + + last_offset = 0; + for (int64_t col = 0; col < array_view->n_children; col++) { + last_offset += sizeof(int32_t); + if (param_lengths[col] == 0) { + param_values[col] = nullptr; + } else { + param_values[col] = reinterpret_cast(param_buffer->data) + last_offset; + } + last_offset += param_lengths[col]; + } + + PGresult* result = + PQexecPrepared(pg_conn, /*stmtName=*/"", + /*nParams=*/bind_schema->n_children, param_values.data(), + param_lengths.data(), param_formats.data(), result_format); + + ExecStatusType pg_status = PQresultStatus(result); + if (pg_status != PGRES_COMMAND_OK && pg_status != PGRES_TUPLES_OK) { + Status status = + MakeStatus(result, "[libpq] Failed to execute prepared statement: {} {}", + PQresStatus(pg_status), PQerrorMessage(pg_conn)); + PQclear(result); + return status; + } + + *result_out = result; + return Status::Ok(); + } + + Status Cleanup(PGconn* pg_conn) { + if (has_tz_field) { + PqResultHelper reset(pg_conn, "SET TIME ZONE '" + tz_setting + "'"); + UNWRAP_STATUS(reset.Execute()); + + PqResultHelper commit(pg_conn, "COMMIT"); + UNWRAP_STATUS(reset.Execute()); + } + + return Status::Ok(); + } + + Status ExecuteCopy(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, + int64_t* rows_affected) { + if (rows_affected) *rows_affected = 0; + + PostgresCopyStreamWriter writer; + UNWRAP_ERRNO(Internal, writer.Init(&bind_schema.value)); + UNWRAP_NANOARROW(na_error, Internal, + writer.InitFieldWriters(type_resolver, &na_error)); + + UNWRAP_NANOARROW(na_error, Internal, writer.WriteHeader(&na_error)); + + while (true) { + UNWRAP_STATUS(PullNextArray()); + if (!current->release) break; + + UNWRAP_ERRNO(Internal, writer.SetArray(¤t.value)); + + // build writer buffer + 
int write_result; + do { + write_result = writer.WriteRecord(&na_error); + } while (write_result == NANOARROW_OK); + + // check if not ENODATA at exit + if (write_result != ENODATA) { + return Status::IO("Error occurred writing COPY data: ", PQerrorMessage(pg_conn)); + } + + UNWRAP_STATUS(FlushCopyWriterToConn(pg_conn, writer)); + + if (rows_affected) *rows_affected += current->length; + writer.Rewind(); + } + + // If there were no arrays in the stream, we haven't flushed yet + UNWRAP_STATUS(FlushCopyWriterToConn(pg_conn, writer)); + + if (PQputCopyEnd(pg_conn, NULL) <= 0) { + return Status::IO("Error message returned by PQputCopyEnd: ", + PQerrorMessage(pg_conn)); + } + + PGresult* result = PQgetResult(pg_conn); + ExecStatusType pg_status = PQresultStatus(result); + if (pg_status != PGRES_COMMAND_OK) { + Status status = + MakeStatus(result, "[libpq] Failed to execute COPY statement: {} {}", + PQresStatus(pg_status), PQerrorMessage(pg_conn)); + PQclear(result); + return status; + } + + PQclear(result); + return Status::Ok(); + } + + Status FlushCopyWriterToConn(PGconn* pg_conn, const PostgresCopyStreamWriter& writer) { + // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max + // size for a single message that we need to respect (1 GiB - 1). Since + // the buffer can be chunked up as much as we want, go for 16 MiB as our + // limit. + // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28 + constexpr int64_t kMaxCopyBufferSize = 0x1000000; + ArrowBuffer buffer = writer.WriteBuffer(); + + auto* data = reinterpret_cast(buffer.data); + int64_t remaining = buffer.size_bytes; + while (remaining > 0) { + int64_t to_write = std::min(remaining, kMaxCopyBufferSize); + if (PQputCopyData(pg_conn, data, to_write) <= 0) { + return Status::IO("Error writing tuple field data: ", PQerrorMessage(pg_conn)); + } + remaining -= to_write; + data += to_write; + } + + return Status::Ok(); + } +}; +} // namespace adbcpq diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index 5a6cdf2412..0f46fff0d7 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -17,26 +17,33 @@ #include "connection.h" +#include #include #include #include #include #include +#include #include #include +#include #include #include #include -#include +#include #include #include "database.h" #include "driver/common/utils.h" -#include "driver/framework/catalog.h" +#include "driver/framework/objects.h" +#include "driver/framework/utility.h" #include "error.h" #include "result_helper.h" +using adbc::driver::Result; +using adbc::driver::Status; + namespace adbcpq { namespace { @@ -50,558 +57,392 @@ static const std::unordered_map kPgTableTypes = { {"table", "r"}, {"view", "v"}, {"materialized_view", "m"}, {"toast_table", "t"}, {"foreign_table", "f"}, {"partitioned_table", "p"}}; -class PqGetObjectsHelper { +static const char* kCatalogQueryAll = "SELECT datname FROM pg_catalog.pg_database"; + +// catalog_name is not a parameter here or on any other queries +// because it will always be the currently connected database. +static const char* kSchemaQueryAll = + "SELECT nspname FROM pg_catalog.pg_namespace WHERE " + "nspname !~ '^pg_' AND nspname <> 'information_schema'"; + +// Parameterized on schema_name, relkind +// Note that when binding relkind as a string it must look like {"r", "v", ...} +// (i.e., double quotes). Binding a binary list element also works. 
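[Editorial note] The double-quoted array literal mentioned in the comment above can be built and bound in text format as in the following sketch. The names BuildRelkindLiteral and ListTables are illustrative only; the driver's own equivalent is PostgresGetObjectsHelper::TableTypesArrayLiteral() further down, and the query here is a trimmed stand-in for kTablesQueryAll.

```cpp
#include <libpq-fe.h>

#include <string>
#include <vector>

// Turns a list of relkind codes such as {"r", "v"} into the text array
// literal {"r", "v"} that PostgreSQL parses when the parameter is bound in
// text format.
std::string BuildRelkindLiteral(const std::vector<std::string>& relkinds) {
  std::string literal = "{";
  for (size_t i = 0; i < relkinds.size(); i++) {
    if (i > 0) literal += ", ";
    literal += "\"" + relkinds[i] + "\"";
  }
  literal += "}";
  return literal;
}

// Usage sketch: both parameters go over in text format (format 0), so the
// length/format arrays may be null and the server infers the array type from
// the ANY($2) context. The caller owns the returned PGresult (PQclear it).
PGresult* ListTables(PGconn* conn, const char* schema) {
  std::string relkinds = BuildRelkindLiteral({"r", "v", "m"});
  const char* values[] = {schema, relkinds.c_str()};
  return PQexecParams(conn,
                      "SELECT c.relname FROM pg_catalog.pg_class c "
                      "JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace "
                      "WHERE n.nspname = $1 AND c.relkind = ANY($2)",
                      /*nParams=*/2, /*paramTypes=*/nullptr, values,
                      /*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
                      /*resultFormat=*/0);
}
```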
+static const char* kTablesQueryAll = + "SELECT c.relname, CASE c.relkind WHEN 'r' THEN 'table' WHEN 'v' THEN 'view' " + "WHEN 'm' THEN 'materialized view' WHEN 't' THEN 'TOAST table' " + "WHEN 'f' THEN 'foreign table' WHEN 'p' THEN 'partitioned table' END " + "AS reltype FROM pg_catalog.pg_class c " + "LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace " + "WHERE pg_catalog.pg_table_is_visible(c.oid) AND n.nspname = $1 AND c.relkind = " + "ANY($2)"; + +// Parameterized on schema_name, table_name +static const char* kColumnsQueryAll = + "SELECT attr.attname, attr.attnum, " + "pg_catalog.col_description(cls.oid, attr.attnum) " + "FROM pg_catalog.pg_attribute AS attr " + "INNER JOIN pg_catalog.pg_class AS cls ON attr.attrelid = cls.oid " + "INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " + "WHERE attr.attnum > 0 AND NOT attr.attisdropped " + "AND nsp.nspname LIKE $1 AND cls.relname LIKE $2"; + +// Parameterized on schema_name, table_name +static const char* kConstraintsQueryAll = + "WITH fk_unnest AS ( " + " SELECT " + " con.conname, " + " 'FOREIGN KEY' AS contype, " + " conrelid, " + " UNNEST(con.conkey) AS conkey, " + " confrelid, " + " UNNEST(con.confkey) AS confkey " + " FROM pg_catalog.pg_constraint AS con " + " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = conrelid " + " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " + " WHERE con.contype = 'f' AND nsp.nspname = $1 " + " AND cls.relname = $2 " + "), " + "fk_names AS ( " + " SELECT " + " fk_unnest.conname, " + " fk_unnest.contype, " + " fk_unnest.conkey, " + " fk_unnest.confkey, " + " attr.attname, " + " fnsp.nspname AS fschema, " + " fcls.relname AS ftable, " + " fattr.attname AS fattname " + " FROM fk_unnest " + " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = fk_unnest.conrelid " + " INNER JOIN pg_catalog.pg_class AS fcls ON fcls.oid = fk_unnest.confrelid " + " INNER JOIN pg_catalog.pg_namespace AS fnsp ON fnsp.oid = fcls.relnamespace" + " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = " + "fk_unnest.conkey " + " AND attr.attrelid = fk_unnest.conrelid " + " LEFT JOIN pg_catalog.pg_attribute AS fattr ON fattr.attnum = " + "fk_unnest.confkey " + " AND fattr.attrelid = fk_unnest.confrelid " + "), " + "fkeys AS ( " + " SELECT " + " conname, " + " contype, " + " ARRAY_AGG(attname ORDER BY conkey) AS colnames, " + " fschema, " + " ftable, " + " ARRAY_AGG(fattname ORDER BY confkey) AS fcolnames " + " FROM fk_names " + " GROUP BY " + " conname, " + " contype, " + " fschema, " + " ftable " + "), " + "other_constraints AS ( " + " SELECT con.conname, CASE con.contype WHEN 'c' THEN 'CHECK' WHEN 'u' THEN " + " 'UNIQUE' WHEN 'p' THEN 'PRIMARY KEY' END AS contype, " + " ARRAY_AGG(attr.attname) AS colnames " + " FROM pg_catalog.pg_constraint AS con " + " CROSS JOIN UNNEST(conkey) AS conkeys " + " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = con.conrelid " + " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " + " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = conkeys " + " AND cls.oid = attr.attrelid " + " WHERE con.contype IN ('c', 'u', 'p') AND nsp.nspname = $1 " + " AND cls.relname = $2 " + " GROUP BY conname, contype " + ") " + "SELECT " + " conname, contype, colnames, fschema, ftable, fcolnames " + "FROM fkeys " + "UNION ALL " + "SELECT " + " conname, contype, colnames, NULL, NULL, NULL " + "FROM other_constraints"; + +class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper { public: - 
PqGetObjectsHelper(PGconn* conn, int depth, const char* catalog, const char* db_schema, - const char* table_name, const char** table_types, - const char* column_name, struct ArrowSchema* schema, - struct ArrowArray* array, struct AdbcError* error) - : conn_(conn), - depth_(depth), - catalog_(catalog), - db_schema_(db_schema), - table_name_(table_name), - table_types_(table_types), - column_name_(column_name), - schema_(schema), - array_(array), - error_(error) { - na_error_ = {0}; - } - - AdbcStatusCode GetObjects() { - RAISE_ADBC(InitArrowArray()); - - catalog_name_col_ = array_->children[0]; - catalog_db_schemas_col_ = array_->children[1]; - catalog_db_schemas_items_ = catalog_db_schemas_col_->children[0]; - db_schema_name_col_ = catalog_db_schemas_items_->children[0]; - db_schema_tables_col_ = catalog_db_schemas_items_->children[1]; - schema_table_items_ = db_schema_tables_col_->children[0]; - table_name_col_ = schema_table_items_->children[0]; - table_type_col_ = schema_table_items_->children[1]; - - table_columns_col_ = schema_table_items_->children[2]; - table_columns_items_ = table_columns_col_->children[0]; - column_name_col_ = table_columns_items_->children[0]; - column_position_col_ = table_columns_items_->children[1]; - column_remarks_col_ = table_columns_items_->children[2]; - - table_constraints_col_ = schema_table_items_->children[3]; - table_constraints_items_ = table_constraints_col_->children[0]; - constraint_name_col_ = table_constraints_items_->children[0]; - constraint_type_col_ = table_constraints_items_->children[1]; - - constraint_column_names_col_ = table_constraints_items_->children[2]; - constraint_column_name_col_ = constraint_column_names_col_->children[0]; - - constraint_column_usages_col_ = table_constraints_items_->children[3]; - constraint_column_usage_items_ = constraint_column_usages_col_->children[0]; - fk_catalog_col_ = constraint_column_usage_items_->children[0]; - fk_db_schema_col_ = constraint_column_usage_items_->children[1]; - fk_table_col_ = constraint_column_usage_items_->children[2]; - fk_column_name_col_ = constraint_column_usage_items_->children[3]; - - RAISE_ADBC(AppendCatalogs()); - RAISE_ADBC(FinishArrowArray()); - return ADBC_STATUS_OK; + explicit PostgresGetObjectsHelper(PGconn* conn) + : current_database_(PQdb(conn)), + all_catalogs_(conn, kCatalogQueryAll), + some_catalogs_(conn, CatalogQuery()), + all_schemas_(conn, kSchemaQueryAll), + some_schemas_(conn, SchemaQuery()), + all_tables_(conn, kTablesQueryAll), + some_tables_(conn, TablesQuery()), + all_columns_(conn, kColumnsQueryAll), + some_columns_(conn, ColumnsQuery()), + all_constraints_(conn, kConstraintsQueryAll), + some_constraints_(conn, ConstraintsQuery()) {} + + // Allow Redshift to execute this query without constraints + // TODO(paleolimbot): Investigate to see if we can simplify the constraits query so that + // it works on both! 
+ void SetEnableConstraints(bool enable_constraints) { + enable_constraints_ = enable_constraints; } - private: - AdbcStatusCode InitArrowArray() { - RAISE_ADBC(adbc::driver::AdbcInitConnectionObjectsSchema(schema_).ToAdbc(error_)); - - CHECK_NA_DETAIL(INTERNAL, ArrowArrayInitFromSchema(array_, schema_, &na_error_), - &na_error_, error_); - - CHECK_NA(INTERNAL, ArrowArrayStartAppending(array_), error_); - return ADBC_STATUS_OK; + Status Load(adbc::driver::GetObjectsDepth depth, + std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + const std::vector& table_types) override { + return Status::Ok(); } - AdbcStatusCode AppendSchemas(std::string db_name) { - // postgres only allows you to list schemas for the currently connected db - if (!strcmp(db_name.c_str(), PQdb(conn_))) { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size*/ 256)) { - return ADBC_STATUS_INTERNAL; - } - - const char* stmt = - "SELECT nspname FROM pg_catalog.pg_namespace WHERE " - "nspname !~ '^pg_' AND nspname <> 'information_schema'"; - - if (StringBuilderAppend(&query, "%s", stmt)) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - - std::vector params; - if (db_schema_ != NULL) { - if (StringBuilderAppend(&query, "%s", " AND nspname = $1")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - params.push_back(db_schema_); - } - - auto result_helper = - PqResultHelper{conn_, std::string(query.buffer), params, error_}; - StringBuilderReset(&query); + Status LoadCatalogs(std::optional catalog_filter) override { + if (catalog_filter.has_value()) { + UNWRAP_STATUS(some_catalogs_.Execute({std::string(*catalog_filter)})); + next_catalog_ = some_catalogs_.Row(-1); + } else { + UNWRAP_STATUS(all_catalogs_.Execute()); + next_catalog_ = all_catalogs_.Row(-1); + } - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + return Status::Ok(); + }; - for (PqResultRow row : result_helper) { - const char* schema_name = row[0].data; - CHECK_NA(INTERNAL, - ArrowArrayAppendString(db_schema_name_col_, ArrowCharView(schema_name)), - error_); - if (depth_ == ADBC_OBJECT_DEPTH_DB_SCHEMAS) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(db_schema_tables_col_, 1), error_); - } else { - RAISE_ADBC(AppendTables(std::string(schema_name))); - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_items_), error_); - } + Result> NextCatalog() override { + next_catalog_ = next_catalog_.Next(); + if (!next_catalog_.IsValid()) { + return std::nullopt; } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_col_), error_); - return ADBC_STATUS_OK; + return next_catalog_[0].value(); } - AdbcStatusCode AppendCatalogs() { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL; - - if (StringBuilderAppend(&query, "%s", "SELECT datname FROM pg_catalog.pg_database")) { - return ADBC_STATUS_INTERNAL; + Status LoadSchemas(std::string_view catalog, + std::optional schema_filter) override { + // PostgreSQL can only list for the current database + if (catalog != current_database_) { + return Status::Ok(); } - std::vector params; - if (catalog_ != NULL) { - if (StringBuilderAppend(&query, "%s", " WHERE datname = $1")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - params.push_back(catalog_); + if (schema_filter.has_value()) { + 
UNWRAP_STATUS(some_schemas_.Execute({std::string(*schema_filter)})); + next_schema_ = some_schemas_.Row(-1); + } else { + UNWRAP_STATUS(all_schemas_.Execute()); + next_schema_ = all_schemas_.Row(-1); } + return Status::Ok(); + }; - PqResultHelper result_helper = - PqResultHelper{conn_, std::string(query.buffer), params, error_}; - StringBuilderReset(&query); - - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); - - for (PqResultRow row : result_helper) { - const char* db_name = row[0].data; - CHECK_NA(INTERNAL, - ArrowArrayAppendString(catalog_name_col_, ArrowCharView(db_name)), error_); - if (depth_ == ADBC_OBJECT_DEPTH_CATALOGS) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col_, 1), error_); - } else { - RAISE_ADBC(AppendSchemas(std::string(db_name))); - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(array_), error_); + Result> NextSchema() override { + next_schema_ = next_schema_.Next(); + if (!next_schema_.IsValid()) { + return std::nullopt; } - return ADBC_STATUS_OK; + return next_schema_[0].value(); } - AdbcStatusCode AppendTables(std::string schema_name) { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size*/ 512)) { - return ADBC_STATUS_INTERNAL; - } + Status LoadTables(std::string_view catalog, std::string_view schema, + std::optional table_filter, + const std::vector& table_types) override { + std::string table_types_bind = TableTypesArrayLiteral(table_types); - std::vector params = {schema_name}; - const char* stmt = - "SELECT c.relname, CASE c.relkind WHEN 'r' THEN 'table' WHEN 'v' THEN 'view' " - "WHEN 'm' THEN 'materialized view' WHEN 't' THEN 'TOAST table' " - "WHEN 'f' THEN 'foreign table' WHEN 'p' THEN 'partitioned table' END " - "AS reltype FROM pg_catalog.pg_class c " - "LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace " - "WHERE c.relkind IN ('r','v','m','t','f','p') " - "AND pg_catalog.pg_table_is_visible(c.oid) AND n.nspname = $1"; - - if (StringBuilderAppend(&query, "%s", stmt)) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; + if (table_filter.has_value()) { + UNWRAP_STATUS(some_tables_.Execute( + {std::string(schema), table_types_bind, std::string(*table_filter)})); + next_table_ = some_tables_.Row(-1); + } else { + UNWRAP_STATUS(all_tables_.Execute({std::string(schema), table_types_bind})); + next_table_ = all_tables_.Row(-1); } - if (table_name_ != nullptr) { - if (StringBuilderAppend(&query, "%s", " AND c.relname LIKE $2")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } + return Status::Ok(); + }; - params.push_back(std::string(table_name_)); + Result> NextTable() override { + next_table_ = next_table_.Next(); + if (!next_table_.IsValid()) { + return std::nullopt; } - if (table_types_ != nullptr) { - std::vector table_type_filter; - const char** table_types = table_types_; - while (*table_types != NULL) { - auto table_type_str = std::string(*table_types); - auto search = kPgTableTypes.find(table_type_str); - if (search != kPgTableTypes.end()) { - table_type_filter.push_back(search->second); - } - table_types++; - } + return Table{next_table_[0].value(), next_table_[1].value()}; + } - if (!table_type_filter.empty()) { - std::ostringstream oss; - bool first = true; - oss << "("; - for (const auto& str : table_type_filter) { - if (!first) { - oss << ", "; - } - oss << "'" << str << "'"; - first = false; - } - oss << ")"; + Status LoadColumns(std::string_view catalog, std::string_view schema, + std::string_view 
table, + std::optional column_filter) override { + if (column_filter.has_value()) { + UNWRAP_STATUS(some_columns_.Execute( + {std::string(schema), std::string(table), std::string(*column_filter)})); + next_column_ = some_columns_.Row(-1); + } else { + UNWRAP_STATUS(all_columns_.Execute({std::string(schema), std::string(table)})); + next_column_ = all_columns_.Row(-1); + } - if (StringBuilderAppend(&query, "%s%s", " AND c.relkind IN ", - oss.str().c_str())) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } + if (enable_constraints_) { + if (column_filter.has_value()) { + UNWRAP_STATUS(some_constraints_.Execute( + {std::string(schema), std::string(table), std::string(*column_filter)})) + next_constraint_ = some_constraints_.Row(-1); } else { - // no matching table type means no records should come back - if (StringBuilderAppend(&query, "%s", " AND false")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } + UNWRAP_STATUS( + all_constraints_.Execute({std::string(schema), std::string(table)})); + next_constraint_ = all_constraints_.Row(-1); } } - auto result_helper = PqResultHelper{conn_, query.buffer, params, error_}; - StringBuilderReset(&query); + return Status::Ok(); + }; - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); - for (PqResultRow row : result_helper) { - const char* table_name = row[0].data; - const char* table_type = row[1].data; + Result> NextColumn() override { + next_column_ = next_column_.Next(); + if (!next_column_.IsValid()) { + return std::nullopt; + } - CHECK_NA(INTERNAL, - ArrowArrayAppendString(table_name_col_, ArrowCharView(table_name)), - error_); - CHECK_NA(INTERNAL, - ArrowArrayAppendString(table_type_col_, ArrowCharView(table_type)), - error_); - if (depth_ == ADBC_OBJECT_DEPTH_TABLES) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(table_columns_col_, 1), error_); - CHECK_NA(INTERNAL, ArrowArrayAppendNull(table_constraints_col_, 1), error_); - } else { - auto table_name_s = std::string(table_name); - RAISE_ADBC(AppendColumns(schema_name, table_name_s)); - RAISE_ADBC(AppendConstraints(schema_name, table_name_s)); - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(schema_table_items_), error_); + Column col; + col.column_name = next_column_[0].value(); + UNWRAP_RESULT(int64_t ordinal_position, next_column_[1].ParseInteger()); + col.ordinal_position = static_cast(ordinal_position); + if (!next_column_[2].is_null) { + col.remarks = next_column_[2].value(); } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_tables_col_), error_); - return ADBC_STATUS_OK; + return col; } - AdbcStatusCode AppendColumns(std::string schema_name, std::string table_name) { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size*/ 512)) { - return ADBC_STATUS_INTERNAL; - } - - std::vector params = {schema_name, table_name}; - const char* stmt = - "SELECT attr.attname, attr.attnum, " - "pg_catalog.col_description(cls.oid, attr.attnum) " - "FROM pg_catalog.pg_attribute AS attr " - "INNER JOIN pg_catalog.pg_class AS cls ON attr.attrelid = cls.oid " - "INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " - "WHERE attr.attnum > 0 AND NOT attr.attisdropped " - "AND nsp.nspname LIKE $1 AND cls.relname LIKE $2"; - - if (StringBuilderAppend(&query, "%s", stmt)) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; + Result> NextConstraint() override { + next_constraint_ = next_constraint_.Next(); + if (!next_constraint_.IsValid()) { + return 
std::nullopt; } - if (column_name_ != NULL) { - if (StringBuilderAppend(&query, "%s", " AND attr.attname LIKE $3")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } + Constraint out; + out.name = next_constraint_[0].data; + out.type = next_constraint_[1].data; - params.push_back(std::string(column_name_)); + UNWRAP_RESULT(constraint_fcolumn_names_, next_constraint_[2].ParseTextArray()); + std::vector fcolumn_names_view; + for (const std::string& item : constraint_fcolumn_names_) { + fcolumn_names_view.push_back(item); } + out.column_names = std::move(fcolumn_names_view); - auto result_helper = PqResultHelper{conn_, query.buffer, params, error_}; - StringBuilderReset(&query); + if (out.type == "FOREIGN KEY") { + assert(!next_constraint_[3].is_null); + assert(!next_constraint_[3].is_null); + assert(!next_constraint_[4].is_null); + assert(!next_constraint_[5].is_null); - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + out.usage = std::vector(); + UNWRAP_RESULT(constraint_fkey_names_, next_constraint_[5].ParseTextArray()); - for (PqResultRow row : result_helper) { - const char* column_name = row[0].data; - const char* position = row[1].data; - - CHECK_NA(INTERNAL, - ArrowArrayAppendString(column_name_col_, ArrowCharView(column_name)), - error_); - int ival = atol(position); - CHECK_NA(INTERNAL, - ArrowArrayAppendInt(column_position_col_, static_cast(ival)), - error_); - if (row[2].is_null) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(column_remarks_col_, 1), error_); - } else { - const char* remarks = row[2].data; - CHECK_NA(INTERNAL, - ArrowArrayAppendString(column_remarks_col_, ArrowCharView(remarks)), - error_); - } + for (const auto& item : constraint_fkey_names_) { + ConstraintUsage usage; + usage.catalog = current_database_; + usage.schema = next_constraint_[3].data; + usage.table = next_constraint_[4].data; + usage.column = item; - // no xdbc_ values for now - for (auto i = 3; i < 19; i++) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(table_columns_items_->children[i], 1), - error_); + out.usage->push_back(usage); } - - CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_columns_items_), error_); } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_columns_col_), error_); - return ADBC_STATUS_OK; + return out; } - // libpq PQexecParams can use either text or binary transfers - // For now we are using text transfer internally, so arrays are sent - // back like {element1, element2} within a const char* - std::vector PqTextArrayToVector(std::string text_array) { - text_array.erase(0, 1); - text_array.erase(text_array.size() - 1); - - std::vector elements; - std::stringstream ss(std::move(text_array)); - std::string tmp; - - while (getline(ss, tmp, ',')) { - elements.push_back(std::move(tmp)); - } - - return elements; + private: + std::string current_database_; + + // Ready-to-Execute() queries + PqResultHelper all_catalogs_; + PqResultHelper some_catalogs_; + PqResultHelper all_schemas_; + PqResultHelper some_schemas_; + PqResultHelper all_tables_; + PqResultHelper some_tables_; + PqResultHelper all_columns_; + PqResultHelper some_columns_; + PqResultHelper all_constraints_; + PqResultHelper some_constraints_; + + // On Redshift, the constraints query fails + bool enable_constraints_{true}; + + // Iterator state for the catalogs/schema/table/column queries + PqResultRow next_catalog_; + PqResultRow next_schema_; + PqResultRow next_table_; + PqResultRow next_column_; + PqResultRow next_constraint_; + + // Owning variants required because the 
framework versions of these + // are all based on string_view and the result helper can only parse arrays + // into std::vector. + std::vector constraint_fcolumn_names_; + std::vector constraint_fkey_names_; + + // Queries that are slightly modified versions of the generic queries that allow + // the filter for that level to be passed through as a parameter. Defined here + // because global strings should be const char* according to cpplint and using + // the + operator to concatenate them is the most concise way to construct them. + + // Parameterized on catalog_name + static std::string CatalogQuery() { + return std::string(kCatalogQueryAll) + " WHERE datname = $1"; } - AdbcStatusCode AppendConstraints(std::string schema_name, std::string table_name) { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size*/ 4096)) { - return ADBC_STATUS_INTERNAL; - } - - std::vector params = {schema_name, table_name}; - const char* stmt = - "WITH fk_unnest AS ( " - " SELECT " - " con.conname, " - " 'FOREIGN KEY' AS contype, " - " conrelid, " - " UNNEST(con.conkey) AS conkey, " - " confrelid, " - " UNNEST(con.confkey) AS confkey " - " FROM pg_catalog.pg_constraint AS con " - " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = conrelid " - " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " - " WHERE con.contype = 'f' AND nsp.nspname LIKE $1 " - " AND cls.relname LIKE $2 " - "), " - "fk_names AS ( " - " SELECT " - " fk_unnest.conname, " - " fk_unnest.contype, " - " fk_unnest.conkey, " - " fk_unnest.confkey, " - " attr.attname, " - " fnsp.nspname AS fschema, " - " fcls.relname AS ftable, " - " fattr.attname AS fattname " - " FROM fk_unnest " - " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = fk_unnest.conrelid " - " INNER JOIN pg_catalog.pg_class AS fcls ON fcls.oid = fk_unnest.confrelid " - " INNER JOIN pg_catalog.pg_namespace AS fnsp ON fnsp.oid = fcls.relnamespace" - " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = " - "fk_unnest.conkey " - " AND attr.attrelid = fk_unnest.conrelid " - " LEFT JOIN pg_catalog.pg_attribute AS fattr ON fattr.attnum = " - "fk_unnest.confkey " - " AND fattr.attrelid = fk_unnest.confrelid " - "), " - "fkeys AS ( " - " SELECT " - " conname, " - " contype, " - " ARRAY_AGG(attname ORDER BY conkey) AS colnames, " - " fschema, " - " ftable, " - " ARRAY_AGG(fattname ORDER BY confkey) AS fcolnames " - " FROM fk_names " - " GROUP BY " - " conname, " - " contype, " - " fschema, " - " ftable " - "), " - "other_constraints AS ( " - " SELECT con.conname, CASE con.contype WHEN 'c' THEN 'CHECK' WHEN 'u' THEN " - " 'UNIQUE' WHEN 'p' THEN 'PRIMARY KEY' END AS contype, " - " ARRAY_AGG(attr.attname) AS colnames " - " FROM pg_catalog.pg_constraint AS con " - " CROSS JOIN UNNEST(conkey) AS conkeys " - " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = con.conrelid " - " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " - " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = conkeys " - " AND cls.oid = attr.attrelid " - " WHERE con.contype IN ('c', 'u', 'p') AND nsp.nspname LIKE $1 " - " AND cls.relname LIKE $2 " - " GROUP BY conname, contype " - ") " - "SELECT " - " conname, contype, colnames, fschema, ftable, fcolnames " - "FROM fkeys " - "UNION ALL " - "SELECT " - " conname, contype, colnames, NULL, NULL, NULL " - "FROM other_constraints"; - - if (StringBuilderAppend(&query, "%s", stmt)) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - - 
if (column_name_ != NULL) { - if (StringBuilderAppend(&query, "%s", " WHERE conname LIKE $3")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - - params.push_back(std::string(column_name_)); - } + // Parameterized on schema_name + static std::string SchemaQuery() { + return std::string(kSchemaQueryAll) + " AND nspname = $1"; + } - auto result_helper = PqResultHelper{conn_, query.buffer, params, error_}; - StringBuilderReset(&query); + // Parameterized on schema_name, relkind, table_name + static std::string TablesQuery() { + return std::string(kTablesQueryAll) + " AND c.relname LIKE $3"; + } - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + // Parameterized on schema_name, table_name, column_name + static std::string ColumnsQuery() { + return std::string(kColumnsQueryAll) + " AND attr.attname LIKE $3"; + } - for (PqResultRow row : result_helper) { - const char* constraint_name = row[0].data; - const char* constraint_type = row[1].data; + // Parameterized on schema_name, table_name, column_name + static std::string ConstraintsQuery() { + return std::string(kConstraintsQueryAll) + " WHERE conname LIKE $3"; + } - CHECK_NA( - INTERNAL, - ArrowArrayAppendString(constraint_name_col_, ArrowCharView(constraint_name)), - error_); + std::string TableTypesArrayLiteral(const std::vector& table_types) { + std::stringstream table_types_bind; + table_types_bind << "{"; + int table_types_bind_len = 0; - CHECK_NA( - INTERNAL, - ArrowArrayAppendString(constraint_type_col_, ArrowCharView(constraint_type)), - error_); + if (table_types.empty()) { + for (const auto& item : kPgTableTypes) { + if (table_types_bind_len > 0) { + table_types_bind << ", "; + } - auto constraint_column_names = PqTextArrayToVector(std::string(row[2].data)); - for (const auto& constraint_column_name : constraint_column_names) { - CHECK_NA(INTERNAL, - ArrowArrayAppendString(constraint_column_name_col_, - ArrowCharView(constraint_column_name.c_str())), - error_); + table_types_bind << "\"" << item.second << "\""; + table_types_bind_len++; } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_names_col_), error_); - - if (!strcmp(constraint_type, "FOREIGN KEY")) { - assert(!row[3].is_null); - assert(!row[4].is_null); - assert(!row[5].is_null); - - const char* constraint_ftable_schema = row[3].data; - const char* constraint_ftable_name = row[4].data; - auto constraint_fcolumn_names = PqTextArrayToVector(std::string(row[5].data)); - for (const auto& constraint_fcolumn_name : constraint_fcolumn_names) { - CHECK_NA(INTERNAL, - ArrowArrayAppendString(fk_catalog_col_, ArrowCharView(PQdb(conn_))), - error_); - CHECK_NA(INTERNAL, - ArrowArrayAppendString(fk_db_schema_col_, - ArrowCharView(constraint_ftable_schema)), - error_); - CHECK_NA(INTERNAL, - ArrowArrayAppendString(fk_table_col_, - ArrowCharView(constraint_ftable_name)), - error_); - CHECK_NA(INTERNAL, - ArrowArrayAppendString(fk_column_name_col_, - ArrowCharView(constraint_fcolumn_name.c_str())), - error_); - - CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_usage_items_), - error_); + } else { + for (auto type : table_types) { + const auto maybe_item = kPgTableTypes.find(std::string(type)); + if (maybe_item == kPgTableTypes.end()) { + continue; } - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_usages_col_), error_); - CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_constraints_items_), error_); - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_constraints_col_), error_); - return 
ADBC_STATUS_OK; - } + if (table_types_bind_len > 0) { + table_types_bind << ", "; + } - AdbcStatusCode FinishArrowArray() { - CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array_, &na_error_), - &na_error_, error_); + table_types_bind << "\"" << maybe_item->second << "\""; + table_types_bind_len++; + } + } - return ADBC_STATUS_OK; + table_types_bind << "}"; + return table_types_bind.str(); } - - PGconn* conn_; - int depth_; - const char* catalog_; - const char* db_schema_; - const char* table_name_; - const char** table_types_; - const char* column_name_; - struct ArrowSchema* schema_; - struct ArrowArray* array_; - struct AdbcError* error_; - struct ArrowError na_error_; - struct ArrowArray* catalog_name_col_; - struct ArrowArray* catalog_db_schemas_col_; - struct ArrowArray* catalog_db_schemas_items_; - struct ArrowArray* db_schema_name_col_; - struct ArrowArray* db_schema_tables_col_; - struct ArrowArray* schema_table_items_; - struct ArrowArray* table_name_col_; - struct ArrowArray* table_type_col_; - struct ArrowArray* table_columns_col_; - struct ArrowArray* table_columns_items_; - struct ArrowArray* column_name_col_; - struct ArrowArray* column_position_col_; - struct ArrowArray* column_remarks_col_; - struct ArrowArray* table_constraints_col_; - struct ArrowArray* table_constraints_items_; - struct ArrowArray* constraint_name_col_; - struct ArrowArray* constraint_type_col_; - struct ArrowArray* constraint_column_names_col_; - struct ArrowArray* constraint_column_name_col_; - struct ArrowArray* constraint_column_usages_col_; - struct ArrowArray* constraint_column_usage_items_; - struct ArrowArray* fk_catalog_col_; - struct ArrowArray* fk_db_schema_col_; - struct ArrowArray* fk_table_col_; - struct ArrowArray* fk_column_name_col_; }; // A notice processor that does nothing with notices. In the future we can log @@ -641,117 +482,120 @@ AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { return ADBC_STATUS_OK; } -AdbcStatusCode PostgresConnection::PostgresConnectionGetInfoImpl( - const uint32_t* info_codes, size_t info_codes_length, struct ArrowSchema* schema, - struct ArrowArray* array, struct AdbcError* error) { - RAISE_ADBC(adbc::driver::AdbcInitConnectionGetInfoSchema(schema, array).ToAdbc(error)); +AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, + const uint32_t* info_codes, + size_t info_codes_length, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!info_codes) { + info_codes = kSupportedInfoCodes; + info_codes_length = sizeof(kSupportedInfoCodes) / sizeof(kSupportedInfoCodes[0]); + } + + std::vector infos; for (size_t i = 0; i < info_codes_length; i++) { switch (info_codes[i]) { case ADBC_INFO_VENDOR_NAME: - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString(array, info_codes[i], - "PostgreSQL") - .ToAdbc(error)); + infos.push_back({info_codes[i], std::string(VendorName())}); break; case ADBC_INFO_VENDOR_VERSION: { - const char* stmt = "SHOW server_version_num"; - auto result_helper = PqResultHelper{conn_, std::string(stmt), error}; - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); - auto it = result_helper.begin(); - if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt); - return ADBC_STATUS_INTERNAL; + if (VendorName() == "Redshift") { + const std::array& version = VendorVersion(); + std::string version_string = std::to_string(version[0]) + "." + + std::to_string(version[1]) + "." 
+ + std::to_string(version[2]); + infos.push_back({info_codes[i], std::move(version_string)}); + + } else { + // Gives a version in the form 140000 instead of 14.0.0 + const char* stmt = "SHOW server_version_num"; + auto result_helper = PqResultHelper{conn_, std::string(stmt)}; + RAISE_STATUS(error, result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt); + return ADBC_STATUS_INTERNAL; + } + const char* server_version_num = (*it)[0].data; + infos.push_back({info_codes[i], server_version_num}); } - const char* server_version_num = (*it)[0].data; - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString(array, info_codes[i], - server_version_num) - .ToAdbc(error)); break; } case ADBC_INFO_DRIVER_NAME: - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString( - array, info_codes[i], "ADBC PostgreSQL Driver") - .ToAdbc(error)); + infos.push_back({info_codes[i], "ADBC PostgreSQL Driver"}); break; case ADBC_INFO_DRIVER_VERSION: // TODO(lidavidm): fill in driver version - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString(array, info_codes[i], - "(unknown)") - .ToAdbc(error)); + infos.push_back({info_codes[i], "(unknown)"}); break; case ADBC_INFO_DRIVER_ARROW_VERSION: - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString(array, info_codes[i], - NANOARROW_VERSION) - .ToAdbc(error)); + infos.push_back({info_codes[i], NANOARROW_VERSION}); break; case ADBC_INFO_DRIVER_ADBC_VERSION: - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendInt(array, info_codes[i], - ADBC_VERSION_1_1_0) - .ToAdbc(error)); + infos.push_back({info_codes[i], ADBC_VERSION_1_1_0}); break; default: // Ignore continue; } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error); } - struct ArrowError na_error = {0}; - CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array, &na_error), &na_error, - error); - + RAISE_ADBC(adbc::driver::MakeGetInfoStream(infos, out).ToAdbc(error)); return ADBC_STATUS_OK; } -AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, - const uint32_t* info_codes, - size_t info_codes_length, - struct ArrowArrayStream* out, - struct AdbcError* error) { - if (!info_codes) { - info_codes = kSupportedInfoCodes; - info_codes_length = sizeof(kSupportedInfoCodes) / sizeof(kSupportedInfoCodes[0]); - } - - struct ArrowSchema schema; - std::memset(&schema, 0, sizeof(schema)); - struct ArrowArray array; - std::memset(&array, 0, sizeof(array)); - - AdbcStatusCode status = PostgresConnectionGetInfoImpl(info_codes, info_codes_length, - &schema, &array, error); - if (status != ADBC_STATUS_OK) { - if (schema.release) schema.release(&schema); - if (array.release) array.release(&array); - return status; - } - - return BatchToArrayStream(&array, &schema, out, error); -} - AdbcStatusCode PostgresConnection::GetObjects( - struct AdbcConnection* connection, int depth, const char* catalog, - const char* db_schema, const char* table_name, const char** table_types, + struct AdbcConnection* connection, int c_depth, const char* catalog, + const char* db_schema, const char* table_name, const char** table_type, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error) { - struct ArrowSchema schema; - std::memset(&schema, 0, sizeof(schema)); - struct ArrowArray array; - std::memset(&array, 0, sizeof(array)); + PostgresGetObjectsHelper helper(conn_); + helper.SetEnableConstraints(VendorName() != "Redshift"); + + const auto catalog_filter = + catalog 
? std::make_optional(std::string_view(catalog)) : std::nullopt; + const auto schema_filter = + db_schema ? std::make_optional(std::string_view(db_schema)) : std::nullopt; + const auto table_filter = + table_name ? std::make_optional(std::string_view(table_name)) : std::nullopt; + const auto column_filter = + column_name ? std::make_optional(std::string_view(column_name)) : std::nullopt; + std::vector table_type_filter; + while (table_type && *table_type) { + if (*table_type) { + table_type_filter.push_back(std::string_view(*table_type)); + } + table_type++; + } - PqGetObjectsHelper helper = - PqGetObjectsHelper(conn_, depth, catalog, db_schema, table_name, table_types, - column_name, &schema, &array, error); - AdbcStatusCode status = helper.GetObjects(); + using adbc::driver::GetObjectsDepth; - if (status != ADBC_STATUS_OK) { - if (schema.release) schema.release(&schema); - if (array.release) array.release(&array); - return status; + GetObjectsDepth depth = GetObjectsDepth::kColumns; + switch (c_depth) { + case ADBC_OBJECT_DEPTH_CATALOGS: + depth = GetObjectsDepth::kCatalogs; + break; + case ADBC_OBJECT_DEPTH_COLUMNS: + depth = GetObjectsDepth::kColumns; + break; + case ADBC_OBJECT_DEPTH_DB_SCHEMAS: + depth = GetObjectsDepth::kSchemas; + break; + case ADBC_OBJECT_DEPTH_TABLES: + depth = GetObjectsDepth::kTables; + break; + default: + return Status::InvalidArgument("[libpq] GetObjects: invalid depth ", c_depth) + .ToAdbc(error); } - return BatchToArrayStream(&array, &schema, out, error); + auto status = BuildGetObjects(&helper, depth, catalog_filter, schema_filter, + table_filter, column_filter, table_type_filter, out); + RAISE_STATUS(error, helper.Close()); + RAISE_STATUS(error, status); + + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, @@ -760,12 +604,12 @@ AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) { output = PQdb(conn_); } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { - PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA", {}, error}; - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA()"}; + RAISE_STATUS(error, result_helper.Execute()); auto it = result_helper.begin(); if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + SetError(error, + "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA()'"); return ADBC_STATUS_INTERNAL; } output = (*it)[0].data; @@ -931,10 +775,9 @@ AdbcStatusCode PostgresConnectionGetStatisticsImpl(PGconn* conn, const char* db_ std::string prev_table; { - PqResultHelper result_helper{ - conn, query, {db_schema, table_name ? table_name : "%"}, error}; - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + PqResultHelper result_helper{conn, query}; + RAISE_STATUS(error, + result_helper.Execute({db_schema, table_name ? 
table_name : "%"})); for (PqResultRow row : result_helper) { auto reltuples = row[5].ParseDouble(); @@ -1087,7 +930,8 @@ AdbcStatusCode PostgresConnection::GetStatistics(const char* catalog, return status; } - return BatchToArrayStream(&array, &schema, out, error); + adbc::driver::MakeArrayStream(&schema, &array, out); + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnectionGetStatisticNamesImpl(struct ArrowSchema* schema, @@ -1136,7 +980,9 @@ AdbcStatusCode PostgresConnection::GetStatisticNames(struct ArrowArrayStream* ou if (array.release) array.release(&array); return status; } - return BatchToArrayStream(&array, &schema, out, error); + + adbc::driver::MakeArrayStream(&schema, &array, out); + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, @@ -1146,39 +992,35 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, struct AdbcError* error) { AdbcStatusCode final_status = ADBC_STATUS_OK; + char* quoted = PQescapeIdentifier(conn_, table_name, strlen(table_name)); + std::string table_name_str(quoted); + PQfreemem(quoted); + + if (db_schema != nullptr) { + quoted = PQescapeIdentifier(conn_, db_schema, strlen(db_schema)); + table_name_str = std::string(quoted) + "." + table_name_str; + PQfreemem(quoted); + } + std::string query = "SELECT attname, atttypid " "FROM pg_catalog.pg_class AS cls " "INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid " "INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid " - "WHERE attr.attnum >= 0 AND cls.oid = $1::regclass::oid"; + "WHERE attr.attnum >= 0 AND cls.oid = $1::regclass::oid " + "ORDER BY attr.attnum"; - std::vector params; - if (db_schema != nullptr) { - params.push_back(std::string(db_schema) + "." + table_name); - } else { - params.push_back(table_name); - } + std::vector params = {table_name_str}; - PqResultHelper result_helper = - PqResultHelper{conn_, std::string(query.c_str()), params, error}; + PqResultHelper result_helper = PqResultHelper{conn_, std::string(query.c_str())}; - RAISE_ADBC(result_helper.Prepare()); - auto result = result_helper.Execute(); - if (result != ADBC_STATUS_OK) { - auto error_code = std::string(error->sqlstate, 5); - if ((error_code == "42P01") || (error_code == "42602")) { - return ADBC_STATUS_NOT_FOUND; - } - return result; - } + RAISE_STATUS(error, result_helper.Execute(params)); auto uschema = nanoarrow::UniqueSchema(); ArrowSchemaInit(uschema.get()); CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), result_helper.NumRows()), error); - ArrowError na_error; int row_counter = 0; for (auto row : result_helper) { const char* colname = row[0].data; @@ -1186,14 +1028,15 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, static_cast(std::strtol(row[1].data, /*str_end=*/nullptr, /*base=*/10)); PostgresType pg_type; - if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) { - SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, " (\"", colname, - "\") has unknown type code ", pg_oid); + if (type_resolver_->FindWithDefault(pg_oid, &pg_type) != NANOARROW_OK) { + SetError(error, "%s%d%s%s%s%" PRIu32, "Error resolving type code for column #", + row_counter + 1, " (\"", colname, "\") with oid ", pg_oid); final_status = ADBC_STATUS_NOT_IMPLEMENTED; break; } CHECK_NA(INTERNAL, - pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]), + pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter], + std::string(VendorName())), error); 
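[Editorial note] The identifier-quoting pattern introduced above (PQescapeIdentifier plus PQfreemem, with the result fed to the $1::regclass parameter) can be isolated as in this sketch. QualifyTable is a hypothetical helper name, not part of the driver; error handling is reduced to a null check.

```cpp
#include <libpq-fe.h>

#include <cstring>
#include <string>

// Builds an optionally schema-qualified, double-quoted identifier using
// libpq's quoting rules. PQescapeIdentifier allocates; PQfreemem releases.
std::string QualifyTable(PGconn* conn, const char* schema, const char* table) {
  char* quoted_table = PQescapeIdentifier(conn, table, std::strlen(table));
  if (quoted_table == nullptr) return "";  // quoting failed; caller should check
  std::string qualified(quoted_table);
  PQfreemem(quoted_table);

  if (schema != nullptr) {
    char* quoted_schema = PQescapeIdentifier(conn, schema, std::strlen(schema));
    if (quoted_schema == nullptr) return "";
    qualified = std::string(quoted_schema) + "." + qualified;
    PQfreemem(quoted_schema);
  }

  // e.g. QualifyTable(conn, "public", "my table") -> "public"."my table",
  // which resolves safely through the $1::regclass::oid cast above.
  return qualified;
}
```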
row_counter++; } @@ -1202,54 +1045,17 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, return final_status; } -AdbcStatusCode PostgresConnectionGetTableTypesImpl(struct ArrowSchema* schema, - struct ArrowArray* array, - struct AdbcError* error) { - // See 'relkind' in https://www.postgresql.org/docs/current/catalog-pg-class.html - auto uschema = nanoarrow::UniqueSchema(); - ArrowSchemaInit(uschema.get()); - - CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get(), NANOARROW_TYPE_STRUCT), error); - CHECK_NA(INTERNAL, ArrowSchemaAllocateChildren(uschema.get(), /*num_columns=*/1), - error); - ArrowSchemaInit(uschema.get()->children[0]); - CHECK_NA(INTERNAL, - ArrowSchemaSetType(uschema.get()->children[0], NANOARROW_TYPE_STRING), error); - CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[0], "table_type"), error); - uschema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - - CHECK_NA(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), NULL), error); - CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); - - for (auto const& table_type : kPgTableTypes) { - CHECK_NA(INTERNAL, - ArrowArrayAppendString(array->children[0], - ArrowCharView(table_type.first.c_str())), - error); - CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error); - } - - CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array, NULL), error); - - uschema.move(schema); - return ADBC_STATUS_OK; -} - AdbcStatusCode PostgresConnection::GetTableTypes(struct AdbcConnection* connection, struct ArrowArrayStream* out, struct AdbcError* error) { - struct ArrowSchema schema; - std::memset(&schema, 0, sizeof(schema)); - struct ArrowArray array; - std::memset(&array, 0, sizeof(array)); - - AdbcStatusCode status = PostgresConnectionGetTableTypesImpl(&schema, &array, error); - if (status != ADBC_STATUS_OK) { - if (schema.release) schema.release(&schema); - if (array.release) array.release(&array); - return status; + std::vector table_types; + table_types.reserve(kPgTableTypes.size()); + for (auto const& table_type : kPgTableTypes) { + table_types.push_back(table_type.first); } - return BatchToArrayStream(&array, &schema, out, error); + + RAISE_STATUS(error, adbc::driver::MakeTableTypesStream(table_types, out)); + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database, @@ -1331,10 +1137,12 @@ AdbcStatusCode PostgresConnection::SetOption(const char* key, const char* value, return ADBC_STATUS_OK; } else if (std::strcmp(key, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { // PostgreSQL doesn't accept a parameter here - PqResultHelper result_helper{ - conn_, std::string("SET search_path TO ") + value, {}, error}; - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + char* value_esc = PQescapeIdentifier(conn_, value, strlen(value)); + std::string query = std::string("SET search_path TO ") + value_esc; + PQfreemem(value_esc); + + PqResultHelper result_helper{conn_, query}; + RAISE_STATUS(error, result_helper.Execute()); return ADBC_STATUS_OK; } SetError(error, "%s%s", "[libpq] Unknown option ", key); @@ -1360,4 +1168,10 @@ AdbcStatusCode PostgresConnection::SetOptionInt(const char* key, int64_t value, return ADBC_STATUS_NOT_IMPLEMENTED; } +std::string_view PostgresConnection::VendorName() { return database_->VendorName(); } + +const std::array& PostgresConnection::VendorVersion() { + return database_->VendorVersion(); +} + } // namespace adbcpq diff --git a/c/driver/postgresql/connection.h b/c/driver/postgresql/connection.h index 
5e45e90b75..7683875b5f 100644 --- a/c/driver/postgresql/connection.h +++ b/c/driver/postgresql/connection.h @@ -17,10 +17,11 @@ #pragma once +#include #include #include -#include +#include #include #include "postgres_type.h" @@ -73,13 +74,10 @@ class PostgresConnection { return type_resolver_; } bool autocommit() const { return autocommit_; } + std::string_view VendorName(); + const std::array& VendorVersion(); private: - AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, - size_t info_codes_length, - struct ArrowSchema* schema, - struct ArrowArray* array, - struct AdbcError* error); std::shared_ptr database_; std::shared_ptr type_resolver_; PGconn* conn_; diff --git a/c/driver/postgresql/copy/postgres_copy_reader_test.cc b/c/driver/postgresql/copy/postgres_copy_reader_test.cc index 0d85c256ec..7b9fe230f8 100644 --- a/c/driver/postgresql/copy/postgres_copy_reader_test.cc +++ b/c/driver/postgresql/copy/postgres_copy_reader_test.cc @@ -27,7 +27,7 @@ class PostgresCopyStreamTester { public: ArrowErrorCode Init(const PostgresType& root_type, ArrowError* error = nullptr) { NANOARROW_RETURN_NOT_OK(reader_.Init(root_type)); - NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema(error)); + NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema("PostgreSQL Tester", error)); NANOARROW_RETURN_NOT_OK(reader_.InitFieldReaders(error)); return NANOARROW_OK; } @@ -314,6 +314,41 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadDate) { ASSERT_EQ(data_buffer[1], 47482); } +TEST(PostgresCopyUtilsTest, PostgresCopyReadTime) { + ArrowBufferView data; + data.data.as_uint8 = kTestPgCopyTime; + data.size_bytes = sizeof(kTestPgCopyTime); + + auto col_type = PostgresType(PostgresTypeId::kTime); + PostgresType input_type(PostgresTypeId::kRecord); + input_type.AppendChild("col", col_type); + + PostgresCopyStreamTester tester; + ASSERT_EQ(tester.Init(input_type), NANOARROW_OK); + ASSERT_EQ(tester.ReadAll(&data), ENODATA); + ASSERT_EQ(data.data.as_uint8 - kTestPgCopyTime, sizeof(kTestPgCopyTime)); + ASSERT_EQ(data.size_bytes, 0); + + nanoarrow::UniqueArray array; + ASSERT_EQ(tester.GetArray(array.get()), NANOARROW_OK); + ASSERT_EQ(array->length, 4); + ASSERT_EQ(array->n_children, 1); + + auto validity = reinterpret_cast(array->children[0]->buffers[0]); + auto data_buffer = reinterpret_cast(array->children[0]->buffers[1]); + ASSERT_NE(validity, nullptr); + ASSERT_NE(data_buffer, nullptr); + + ASSERT_TRUE(ArrowBitGet(validity, 0)); + ASSERT_TRUE(ArrowBitGet(validity, 1)); + ASSERT_TRUE(ArrowBitGet(validity, 2)); + ASSERT_FALSE(ArrowBitGet(validity, 3)); + + ASSERT_EQ(data_buffer[0], 0); + ASSERT_EQ(data_buffer[1], 86399000000); + ASSERT_EQ(data_buffer[2], 49376123456); +} + TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric) { ArrowBufferView data; data.data.as_uint8 = kTestPgCopyNumeric; @@ -585,6 +620,86 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadEnum) { ASSERT_EQ(std::string(data_buffer + 2, 3), "sad"); } +TEST(PostgresCopyUtilsTest, PostgresCopyReadJson) { + ArrowBufferView data; + data.data.as_uint8 = kTestPgCopyJson; + data.size_bytes = sizeof(kTestPgCopyJson); + + auto col_type = PostgresType(PostgresTypeId::kJson); + PostgresType input_type(PostgresTypeId::kRecord); + input_type.AppendChild("col", col_type); + + PostgresCopyStreamTester tester; + ASSERT_EQ(tester.Init(input_type), NANOARROW_OK); + ASSERT_EQ(tester.ReadAll(&data), ENODATA); + ASSERT_EQ(data.data.as_uint8 - kTestPgCopyJson, sizeof(kTestPgCopyJson)); + ASSERT_EQ(data.size_bytes, 0); + + nanoarrow::UniqueArray array; + 
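[Editorial note] A brief aside on the expected values in PostgresCopyReadTime above: with PostgreSQL's integer datetime representation (the default on supported servers), a binary TIME field is a big-endian int64 counting microseconds since midnight, so 23:59:59 is (23*3600 + 59*60 + 59) * 1,000,000 = 86,399,000,000. The decode sketch below is not driver code; DecodeBigEndianInt64 is a hypothetical helper, and the payload bytes are copied from kTestPgCopyTime.

```cpp
#include <cstdint>
#include <cstdio>

// Assembles a big-endian int64 from an 8-byte COPY field payload.
int64_t DecodeBigEndianInt64(const uint8_t* field) {
  uint64_t v = 0;
  for (int i = 0; i < 8; i++) {
    v = (v << 8) | field[i];
  }
  return static_cast<int64_t>(v);
}

int main() {
  // The 8 payload bytes of the second TIME value ('23:59:59') in kTestPgCopyTime.
  const uint8_t payload[] = {0x00, 0x00, 0x00, 0x14, 0x1d, 0xc8, 0x1d, 0xc0};
  std::printf("%lld microseconds since midnight\n",
              static_cast<long long>(DecodeBigEndianInt64(payload)));  // 86399000000
  return 0;
}
```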
ASSERT_EQ(tester.GetArray(array.get()), NANOARROW_OK); + ASSERT_EQ(array->length, 3); + ASSERT_EQ(array->n_children, 1); + + auto validity = reinterpret_cast(array->children[0]->buffers[0]); + auto offsets = reinterpret_cast(array->children[0]->buffers[1]); + auto data_buffer = reinterpret_cast(array->children[0]->buffers[2]); + ASSERT_NE(validity, nullptr); + ASSERT_NE(data_buffer, nullptr); + + ASSERT_TRUE(ArrowBitGet(validity, 0)); + ASSERT_TRUE(ArrowBitGet(validity, 1)); + ASSERT_FALSE(ArrowBitGet(validity, 2)); + + ASSERT_EQ(offsets[0], 0); + ASSERT_EQ(offsets[1], 9); + ASSERT_EQ(offsets[2], 18); + ASSERT_EQ(offsets[3], 18); + + ASSERT_EQ(std::string(data_buffer, 9), "[1, 2, 3]"); + ASSERT_EQ(std::string(data_buffer + 9, 9), "[4, 5, 6]"); +} + +TEST(PostgresCopyUtilsTest, PostgresCopyReadJsonb) { + ArrowBufferView data; + data.data.as_uint8 = kTestPgCopyJsonb; + data.size_bytes = sizeof(kTestPgCopyJsonb); + + auto col_type = PostgresType(PostgresTypeId::kJsonb); + PostgresType input_type(PostgresTypeId::kRecord); + input_type.AppendChild("col", col_type); + + struct ArrowError error; + PostgresCopyStreamTester tester; + ASSERT_EQ(tester.Init(input_type), NANOARROW_OK); + ASSERT_EQ(tester.ReadAll(&data, &error), ENODATA) << error.message; + ASSERT_EQ(data.data.as_uint8 - kTestPgCopyJsonb, sizeof(kTestPgCopyJsonb)); + ASSERT_EQ(data.size_bytes, 0); + + nanoarrow::UniqueArray array; + + ASSERT_EQ(tester.GetArray(array.get(), &error), NANOARROW_OK) << error.message; + ASSERT_EQ(array->length, 3); + ASSERT_EQ(array->n_children, 1); + + auto validity = reinterpret_cast(array->children[0]->buffers[0]); + auto offsets = reinterpret_cast(array->children[0]->buffers[1]); + auto data_buffer = reinterpret_cast(array->children[0]->buffers[2]); + ASSERT_NE(validity, nullptr); + ASSERT_NE(data_buffer, nullptr); + + ASSERT_TRUE(ArrowBitGet(validity, 0)); + ASSERT_TRUE(ArrowBitGet(validity, 1)); + ASSERT_FALSE(ArrowBitGet(validity, 2)); + + ASSERT_EQ(offsets[0], 0); + ASSERT_EQ(offsets[1], 9); + ASSERT_EQ(offsets[2], 18); + ASSERT_EQ(offsets[3], 18); + + ASSERT_EQ(std::string(data_buffer, 9), "[1, 2, 3]"); + ASSERT_EQ(std::string(data_buffer + 9, 9), "[4, 5, 6]"); +} + TEST(PostgresCopyUtilsTest, PostgresCopyReadBinary) { ArrowBufferView data; data.data.as_uint8 = kTestPgCopyBinary; diff --git a/c/driver/postgresql/copy/postgres_copy_test_common.h b/c/driver/postgresql/copy/postgres_copy_test_common.h index d685486626..8872ada6d0 100644 --- a/c/driver/postgresql/copy/postgres_copy_test_common.h +++ b/c/driver/postgresql/copy/postgres_copy_test_common.h @@ -21,6 +21,10 @@ namespace adbcpq { +// New cases can be genereated using: +// psql --host 127.0.0.1 --port 5432 --username postgres -c "COPY (SELECT ...) 
TO STDOUT +// WITH (FORMAT binary);" > test.copy Rscript -e "dput(brio::read_file_raw('test.copy'))" + // COPY (SELECT CAST("col" AS BOOLEAN) AS "col" FROM ( VALUES (TRUE), (FALSE), (NULL)) AS // drvd("col")) TO STDOUT; static const uint8_t kTestPgCopyBoolean[] = { @@ -81,6 +85,15 @@ static const uint8_t kTestPgCopyDate[] = { 0x04, 0xff, 0xff, 0x71, 0x54, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x8e, 0xad, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; +// COPY (SELECT CAST("col" AS TIME) AS "col" FROM ( VALUES ('00:00:00'), ('23:59:59'), +// ('13:42:57.123456'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT binary); +static const uint8_t kTestPgCopyTime[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x14, 0x1d, 0xc8, 0x1d, 0xc0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x7f, 0x0b, 0xda, 0x40, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + // COPY (SELECT CAST(col AS TIMESTAMP) FROM ( VALUES ('1900-01-01 12:34:56'), // ('2100-01-01 12:34:56'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT BINARY); static const uint8_t kTestPgCopyTimestamp[] = { @@ -107,6 +120,24 @@ static const uint8_t kTestPgCopyText[] = { 0x03, 0x61, 0x62, 0x63, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x31, 0x32, 0x33, 0x34, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; +// COPY (SELECT CAST(col AS json) AS col FROM (VALUES ('[1, 2, 3]'), ('[4, 5, 6]'), +// (NULL::json)) AS drvd(col)) TO STDOUT WITH (FORMAT binary); +static const uint8_t kTestPgCopyJson[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x09, 0x5b, 0x31, 0x2c, 0x20, 0x32, 0x2c, 0x20, 0x33, 0x5d, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x09, 0x5b, 0x34, 0x2c, 0x20, 0x35, 0x2c, 0x20, 0x36, + 0x5d, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +// COPY (SELECT CAST(col AS jsonb) AS col FROM (VALUES ('[1, 2, 3]'), ('[4, 5, 6]'), +// (NULL::jsonb)) AS drvd(col)) TO STDOUT WITH (FORMAT binary); +static const uint8_t kTestPgCopyJsonb[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x0a, 0x01, 0x5b, 0x31, 0x2c, 0x20, 0x32, 0x2c, 0x20, 0x33, 0x5d, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0a, 0x01, 0x5b, 0x34, 0x2c, 0x20, 0x35, 0x2c, + 0x20, 0x36, 0x5d, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + // COPY (SELECT CAST("col" AS BYTEA) AS "col" FROM ( VALUES (''), ('\x0001'), // ('\x01020304'), ('\xFEFF'), (NULL)) AS drvd("col")) TO STDOUT // WITH (FORMAT binary); diff --git a/c/driver/postgresql/copy/postgres_copy_writer_test.cc b/c/driver/postgresql/copy/postgres_copy_writer_test.cc index 688e4ea79e..5010848cf5 100644 --- a/c/driver/postgresql/copy/postgres_copy_writer_test.cc +++ b/c/driver/postgresql/copy/postgres_copy_writer_test.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
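For orientation when reading the kTestPgCopy* byte arrays above: every fixture follows the same COPY BINARY framing. A minimal annotated sketch (the array name and the single-NULL-row payload are illustrative, not taken from this patch):

#include <cstdint>

// An entire COPY BINARY stream carrying one row with one NULL field:
// an 11-byte signature, a 4-byte flags word, a 4-byte header-extension
// length, then one tuple (int16 field count, and per field an int32 byte
// length where -1 means NULL), terminated by an int16 -1 trailer.
static const uint8_t kExampleSingleNullRow[] = {
    // signature "PGCOPY\n\377\r\n\0"
    0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00,
    // flags (no OIDs included) and header-extension length, both zero
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    // tuple: field count = 1, field length = -1 (NULL)
    0x00, 0x01, 0xff, 0xff, 0xff, 0xff,
    // stream trailer
    0xff, 0xff};

This is also why the writer tests below compare against sizeof(fixture) - 2: as their comments note, the final 0xff 0xff trailer is transmitted separately via PQputCopyData rather than produced by the field writers.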
+#include #include #include @@ -24,16 +25,20 @@ #include "postgres_copy_test_common.h" #include "postgresql/copy/writer.h" +#include "postgresql/database.h" #include "validation/adbc_validation_util.h" +using adbc_validation::IsOkStatus; + namespace adbcpq { class PostgresCopyStreamWriteTester { public: ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array, + const PostgresTypeResolver& type_resolver, struct ArrowError* error = nullptr) { NANOARROW_RETURN_NOT_OK(writer_.Init(schema)); - NANOARROW_RETURN_NOT_OK(writer_.InitFieldWriters(error)); + NANOARROW_RETURN_NOT_OK(writer_.InitFieldWriters(type_resolver, error)); NANOARROW_RETURN_NOT_OK(writer_.SetArray(array)); return NANOARROW_OK; } @@ -67,7 +72,42 @@ class PostgresCopyStreamWriteTester { PostgresCopyStreamWriter writer_; }; -TEST(PostgresCopyUtilsTest, PostgresCopyWriteBoolean) { +static AdbcStatusCode SetupDatabase(struct AdbcDatabase* database, + struct AdbcError* error) { + const char* uri = std::getenv("ADBC_POSTGRESQL_TEST_URI"); + if (!uri) { + ADD_FAILURE() << "Must provide env var ADBC_POSTGRESQL_TEST_URI"; + return ADBC_STATUS_INVALID_ARGUMENT; + } + return AdbcDatabaseSetOption(database, "uri", uri, error); +} + +class PostgresCopyTest : public ::testing::Test { + public: + void SetUp() override { + ASSERT_THAT(AdbcDatabaseNew(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(SetupDatabase(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseInit(&database_, &error_), IsOkStatus(&error_)); + + const auto pg_db = + *reinterpret_cast*>(database_.private_data); + type_resolver_ = pg_db->type_resolver(); + } + void TearDown() override { + if (database_.private_data) { + ASSERT_THAT(AdbcDatabaseRelease(&database_, &error_), IsOkStatus(&error_)); + } + + if (error_.release) error_.release(&error_); + } + + protected: + struct AdbcError error_ = {}; + struct AdbcDatabase database_ = {}; + std::shared_ptr type_resolver_; +}; + +TEST_F(PostgresCopyTest, PostgresCopyWriteBoolean) { adbc_validation::Handle schema; adbc_validation::Handle array; adbc_validation::Handle buffer; @@ -79,7 +119,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteBoolean) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -92,7 +132,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteBoolean) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt8) { +TEST_F(PostgresCopyTest, PostgresCopyWriteInt8) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -103,7 +143,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt8) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -116,7 +156,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt8) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt16) { +TEST_F(PostgresCopyTest, PostgresCopyWriteInt16) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -127,7 +167,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt16) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester 
tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -140,7 +180,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt16) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt32) { +TEST_F(PostgresCopyTest, PostgresCopyWriteInt32) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -151,7 +191,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt32) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -164,7 +204,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt32) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt64) { +TEST_F(PostgresCopyTest, PostgresCopyWriteInt64) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -175,7 +215,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt64) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -188,7 +228,106 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInt64) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteReal) { +// COPY (SELECT CAST("col" AS SMALLINT) AS "col" FROM ( VALUES (0), (255), +// (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT binary); +static const uint8_t kTestPgCopyUInt8[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, + 0x00, 0xff, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +TEST_F(PostgresCopyTest, PostgresCopyWriteUInt8) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_UINT8}}), + ADBC_STATUS_OK); + ASSERT_EQ(adbc_validation::MakeBatch( + &schema.value, &array.value, &na_error, + {0, (std::numeric_limits::max)(), std::nullopt}), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyUInt8) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyUInt8[i]); + } +} + +// COPY (SELECT CAST("col" AS INTEGER) AS "col" FROM ( VALUES (0), (65535), +// (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT binary); +static const uint8_t kTestPgCopyUInt16[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x00, 0xff, 0xff, 
0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +TEST_F(PostgresCopyTest, PostgresCopyWriteUInt16) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_UINT16}}), + ADBC_STATUS_OK); + ASSERT_EQ(adbc_validation::MakeBatch( + &schema.value, &array.value, &na_error, + {0, (std::numeric_limits::max)(), std::nullopt}), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyUInt16) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyUInt16[i]); + } +} + +// COPY (SELECT CAST("col" AS BIGINT) AS "col" FROM ( VALUES (0), (2^32-1), +// (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT binary); +static const uint8_t kTestPgCopyUInt32[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +TEST_F(PostgresCopyTest, PostgresCopyWriteUInt32) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_UINT32}}), + ADBC_STATUS_OK); + ASSERT_EQ(adbc_validation::MakeBatch( + &schema.value, &array.value, &na_error, + {0, (std::numeric_limits::max)(), std::nullopt}), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyUInt32) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyUInt32[i]); + } +} + +TEST_F(PostgresCopyTest, PostgresCopyWriteReal) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -199,7 +338,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteReal) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -212,7 +351,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteReal) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteDoublePrecision) { +TEST_F(PostgresCopyTest, PostgresCopyWriteDoublePrecision) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -223,7 +362,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteDoublePrecision) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + 
ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -236,7 +375,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteDoublePrecision) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteDate) { +TEST_F(PostgresCopyTest, PostgresCopyWriteDate) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -247,7 +386,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteDate) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -260,6 +399,37 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteDate) { } } +TEST_F(PostgresCopyTest, PostgresCopyWriteTime) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + + const enum ArrowTimeUnit unit = NANOARROW_TIME_UNIT_MICRO; + const auto values = + std::vector>{0, 86399000000, 49376123456, std::nullopt}; + + ArrowSchemaInit(&schema.value); + ArrowSchemaSetTypeStruct(&schema.value, 1); + ArrowSchemaSetTypeDateTime(schema->children[0], NANOARROW_TYPE_TIME64, unit, nullptr); + ArrowSchemaSetName(schema->children[0], "col"); + ASSERT_EQ( + adbc_validation::MakeBatch(&schema.value, &array.value, &na_error, values), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyTime) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyTime[i]); + } +} + // This buffer is similar to the read variant above but removes special values // nan, ±inf as they are not supported via the Arrow Decimal types // COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (NULL), (-123.456), @@ -275,7 +445,7 @@ static uint8_t kTestPgCopyNumericWrite[] = { 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11, 0xd0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0xff, 0xff}; -TEST(PostgresCopyUtilsTest, PostgresCopyWriteNumeric) { +TEST_F(PostgresCopyTest, PostgresCopyWriteNumeric) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -306,16 +476,15 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteNumeric) { ArrowSchemaInit(&schema.value); ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), 0); - ASSERT_EQ( - PrivateArrowSchemaSetTypeDecimal(schema.value.children[0], type, precision, scale), - 0); + ASSERT_EQ(ArrowSchemaSetTypeDecimal(schema.value.children[0], type, precision, scale), + 0); ASSERT_EQ(ArrowSchemaSetName(schema.value.children[0], "col"), 0); ASSERT_EQ(adbc_validation::MakeBatch(&schema.value, &array.value, &na_error, values), ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct 
ArrowBuffer buf = tester.WriteBuffer(); @@ -332,7 +501,29 @@ using TimestampTestParamType = std::tuple>>; class PostgresCopyWriteTimestampTest - : public testing::TestWithParam {}; + : public testing::TestWithParam { + void SetUp() override { + ASSERT_THAT(AdbcDatabaseNew(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(SetupDatabase(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseInit(&database_, &error_), IsOkStatus(&error_)); + + const auto pg_db = + *reinterpret_cast*>(database_.private_data); + type_resolver_ = pg_db->type_resolver(); + } + void TearDown() override { + if (database_.private_data) { + ASSERT_THAT(AdbcDatabaseRelease(&database_, &error_), IsOkStatus(&error_)); + } + + if (error_.release) error_.release(&error_); + } + + protected: + struct AdbcError error_ = {}; + struct AdbcDatabase database_ = {}; + std::shared_ptr type_resolver_; +}; TEST_P(PostgresCopyWriteTimestampTest, WritesProperBufferValues) { adbc_validation::Handle schema; @@ -355,7 +546,7 @@ TEST_P(PostgresCopyWriteTimestampTest, WritesProperBufferValues) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -402,7 +593,7 @@ static const std::vector ts_values{ INSTANTIATE_TEST_SUITE_P(PostgresCopyWriteTimestamp, PostgresCopyWriteTimestampTest, testing::ValuesIn(ts_values)); -TEST(PostgresCopyUtilsTest, PostgresCopyWriteInterval) { +TEST_F(PostgresCopyTest, PostgresCopyWriteInterval) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -432,7 +623,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteInterval) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -459,7 +650,29 @@ using DurationTestParamType = std::tuple>>; class PostgresCopyWriteDurationTest - : public testing::TestWithParam {}; + : public testing::TestWithParam { + void SetUp() override { + ASSERT_THAT(AdbcDatabaseNew(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(SetupDatabase(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseInit(&database_, &error_), IsOkStatus(&error_)); + + const auto pg_db = + *reinterpret_cast*>(database_.private_data); + type_resolver_ = pg_db->type_resolver(); + } + void TearDown() override { + if (database_.private_data) { + ASSERT_THAT(AdbcDatabaseRelease(&database_, &error_), IsOkStatus(&error_)); + } + + if (error_.release) error_.release(&error_); + } + + protected: + struct AdbcError error_ = {}; + struct AdbcDatabase database_ = {}; + std::shared_ptr type_resolver_; +}; TEST_P(PostgresCopyWriteDurationTest, WritesProperBufferValues) { adbc_validation::Handle schema; @@ -480,7 +693,7 @@ TEST_P(PostgresCopyWriteDurationTest, WritesProperBufferValues) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -503,7 +716,7 @@ static const std::vector 
duration_params{ INSTANTIATE_TEST_SUITE_P(PostgresCopyWriteDuration, PostgresCopyWriteDurationTest, testing::ValuesIn(duration_params)); -TEST(PostgresCopyUtilsTest, PostgresCopyWriteString) { +TEST_F(PostgresCopyTest, PostgresCopyWriteString) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -514,7 +727,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteString) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -527,7 +740,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteString) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteLargeString) { +TEST_F(PostgresCopyTest, PostgresCopyWriteLargeString) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -539,7 +752,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteLargeString) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -552,7 +765,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteLargeString) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteBinary) { +TEST_F(PostgresCopyTest, PostgresCopyWriteBinary) { adbc_validation::Handle schema; adbc_validation::Handle array; struct ArrowError na_error; @@ -568,7 +781,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteBinary) { ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); const struct ArrowBuffer buf = tester.WriteBuffer(); @@ -581,7 +794,243 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteBinary) { } } -TEST(PostgresCopyUtilsTest, PostgresCopyWriteMultiBatch) { +class PostgresCopyListTest : public testing::TestWithParam { + public: + void SetUp() override { + ASSERT_THAT(AdbcDatabaseNew(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(SetupDatabase(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseInit(&database_, &error_), IsOkStatus(&error_)); + + const auto pg_db = + *reinterpret_cast*>(database_.private_data); + type_resolver_ = pg_db->type_resolver(); + } + void TearDown() override { + if (database_.private_data) { + ASSERT_THAT(AdbcDatabaseRelease(&database_, &error_), IsOkStatus(&error_)); + } + + if (error_.release) error_.release(&error_); + } + + protected: + struct AdbcError error_ = {}; + struct AdbcDatabase database_ = {}; + std::shared_ptr type_resolver_; +}; + +// COPY (SELECT CAST("col" AS SMALLINT ARRAY) AS "col" FROM ( VALUES ('{-123, -1}'), +// ('{0, 1, 123}'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT binary); +static const uint8_t kTestPgCopySmallIntegerArray[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x02, 0x00, + 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0xff, 0x85, 0x00, 0x00, 0x00, 0x02, 0xff, + 0xff, 0x00, 0x01, 0x00, 
0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x7b, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +TEST_P(PostgresCopyListTest, PostgresCopyWriteListSmallInt) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + + ASSERT_EQ(adbc_validation::MakeSchema( + &schema.value, {adbc_validation::SchemaField::Nested( + "col", GetParam(), {{"item", NANOARROW_TYPE_INT16}})}), + ADBC_STATUS_OK); + + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{-123, -1}, std::vector{0, 1, 123}, + std::nullopt}), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopySmallIntegerArray) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopySmallIntegerArray[i]) << "failure at index " << i; + } +} + +TEST_P(PostgresCopyListTest, PostgresCopyWriteListInteger) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + + ASSERT_EQ(adbc_validation::MakeSchema( + &schema.value, {adbc_validation::SchemaField::Nested( + "col", GetParam(), {{"item", NANOARROW_TYPE_INT32}})}), + ADBC_STATUS_OK); + + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{-123, -1}, std::vector{0, 1, 123}, + std::nullopt}), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyIntegerArray) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyIntegerArray[i]) << "failure at index " << i; + } +} + +// COPY (SELECT CAST("col" AS BIGINT ARRAY) AS "col" FROM ( VALUES ('{-123, -1}'), ('{0, +// 1, 123}'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT binary); +static const uint8_t kTestPgCopyBigIntegerArray[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x02, 0x00, + 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x85, 0x00, 0x00, 0x00, 0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x7b, 0x00, 0x01, 
0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +TEST_P(PostgresCopyListTest, PostgresCopyWriteListBigInt) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + + ASSERT_EQ(adbc_validation::MakeSchema( + &schema.value, {adbc_validation::SchemaField::Nested( + "col", GetParam(), {{"item", NANOARROW_TYPE_INT64}})}), + ADBC_STATUS_OK); + + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{-123, -1}, std::vector{0, 1, 123}, + std::nullopt}), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyBigIntegerArray) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyBigIntegerArray[i]) << "failure at index " << i; + } +} + +// COPY (SELECT CAST("col" AS TEXT ARRAY) AS "col" FROM ( VALUES ('{"foo", "bar"}'), +// ('{"baz", "qux", "quux"}'), (NULL)) AS drvd("col")) TO '/tmp/pgout.data' WITH (FORMAT +// binary); +static const uint8_t kTestPgCopyTextArray[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x22, 0x00, + 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x66, 0x6f, 0x6f, + 0x00, 0x00, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2a, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x19, 0x00, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x62, 0x61, + 0x7a, 0x00, 0x00, 0x00, 0x03, 0x71, 0x75, 0x78, 0x00, 0x00, 0x00, 0x04, 0x71, + 0x75, 0x75, 0x78, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +TEST_P(PostgresCopyListTest, PostgresCopyWriteListVarchar) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + + ASSERT_EQ( + adbc_validation::MakeSchema( + &schema.value, {adbc_validation::SchemaField::Nested( + "col", GetParam(), {{"item", NANOARROW_TYPE_STRING}})}), + ADBC_STATUS_OK); + + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{"foo", "bar"}, + std::vector{"baz", "qux", "quux"}, std::nullopt}), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyTextArray) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyTextArray[i]) << "failure at index " << i; + } +} + +INSTANTIATE_TEST_SUITE_P(ArrowListTypes, PostgresCopyListTest, + testing::Values(NANOARROW_TYPE_LIST, NANOARROW_TYPE_LARGE_LIST)); + +// COPY (SELECT CAST("col" AS INTEGER ARRAY) AS "col" FROM ( VALUES ('{1, 2}'), +// ('{-1, -2}'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT BINARY); +static const uint8_t 
kTestPgCopyFixedSizeIntegerArray[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x02, 0x00, + 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x02, 0x00, + 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, + 0x04, 0xff, 0xff, 0xff, 0xfe, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; +TEST_F(PostgresCopyTest, PostgresCopyWriteFixedSizeListInteger) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + + ASSERT_EQ(ArrowSchemaInitFromType(&schema.value, NANOARROW_TYPE_STRUCT), NANOARROW_OK); + ASSERT_EQ(ArrowSchemaAllocateChildren(&schema.value, 1), NANOARROW_OK); + + ArrowSchemaInit(schema->children[0]); + ASSERT_EQ( + ArrowSchemaSetTypeFixedSize(schema->children[0], NANOARROW_TYPE_FIXED_SIZE_LIST, 2), + NANOARROW_OK); + ASSERT_EQ(ArrowSchemaSetName(schema->children[0], "col"), NANOARROW_OK); + ASSERT_EQ(ArrowSchemaSetType(schema->children[0]->children[0], NANOARROW_TYPE_INT32), + NANOARROW_OK); + + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{1, 2}, std::vector{-1, -2}, std::nullopt}), + ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyFixedSizeIntegerArray) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyFixedSizeIntegerArray[i]) + << "failure at index " << i; + } +} + +TEST_F(PostgresCopyTest, PostgresCopyWriteMultiBatch) { // Regression test for https://github.com/apache/arrow-adbc/issues/1310 adbc_validation::Handle schema; adbc_validation::Handle array; @@ -593,7 +1042,7 @@ TEST(PostgresCopyUtilsTest, PostgresCopyWriteMultiBatch) { NANOARROW_OK); PostgresCopyStreamWriteTester tester; - ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); struct ArrowBuffer buf = tester.WriteBuffer(); diff --git a/c/driver/postgresql/copy/reader.h b/c/driver/postgresql/copy/reader.h index 9486df93cb..07f91d545e 100644 --- a/c/driver/postgresql/copy/reader.h +++ b/c/driver/postgresql/copy/reader.h @@ -443,6 +443,47 @@ class PostgresCopyBinaryFieldReader : public PostgresCopyFieldReader { } }; +/// Postgres JSONB emits as the JSON string prefixed with a version number +/// (https://github.com/postgres/postgres/blob/3f44959f47460fb350d25d760cf2384f9aa14e9a/src/backend/utils/adt/jsonb.c#L80-L87 +/// ) Currently there is only one version, so functionally this is a just string prefixed +/// with 0x01. 
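As the comment above notes, a JSONB field on the COPY wire is the JSON text prefixed with a one-byte version number that is currently always 0x01. A minimal standalone sketch of that decoding step, separate from the reader class that follows (the function name and error handling are illustrative):

#include <cstdint>
#include <string>

// `data`/`size` describe the payload of a single non-NULL JSONB field, i.e.
// the bytes remaining after the int32 field length has been consumed.
static bool DecodeJsonbField(const uint8_t* data, int32_t size, std::string* out) {
  if (size < 1 || data[0] != 0x01) {
    return false;  // unknown JSONB wire version
  }
  // Everything after the version byte is plain JSON text.
  out->assign(reinterpret_cast<const char*>(data) + 1,
              static_cast<size_t>(size - 1));
  return true;
}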
+class PostgresCopyJsonbFieldReader : public PostgresCopyFieldReader { + public: + ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array, + ArrowError* error) override { + // -1 for NULL (0 would be empty string) + if (field_size_bytes < 0) { + return ArrowArrayAppendNull(array, 1); + } + + if (field_size_bytes > data->size_bytes) { + ArrowErrorSet(error, "Expected %d bytes of field data but got %d bytes of input", + static_cast(field_size_bytes), + static_cast(data->size_bytes)); // NOLINT(runtime/int) + return EINVAL; + } + + int8_t version; + NANOARROW_RETURN_NOT_OK(ReadChecked(data, &version, error)); + if (version != 1) { + ArrowErrorSet(error, "Expected JSONB binary version 0x01 but got %d", + static_cast(version)); + return NANOARROW_OK; + } + + field_size_bytes -= 1; + NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, data->data.data, field_size_bytes)); + data->data.as_uint8 += field_size_bytes; + data->size_bytes -= field_size_bytes; + + int32_t* offsets = reinterpret_cast(offsets_->data); + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppendInt32(offsets_, offsets[array->length] + field_size_bytes)); + + return AppendValid(array); + } +}; + class PostgresCopyArrayFieldReader : public PostgresCopyFieldReader { public: void InitChild(std::unique_ptr child) { @@ -774,11 +815,15 @@ static inline ArrowErrorCode MakeCopyFieldReader( case PostgresTypeId::kBpchar: case PostgresTypeId::kName: case PostgresTypeId::kEnum: + case PostgresTypeId::kJson: *out = std::make_unique(); return NANOARROW_OK; case PostgresTypeId::kNumeric: *out = std::make_unique(); return NANOARROW_OK; + case PostgresTypeId::kJsonb: + *out = std::make_unique(); + return NANOARROW_OK; default: return ErrorCantConvert(error, pg_type, schema_view); } @@ -852,8 +897,13 @@ static inline ArrowErrorCode MakeCopyFieldReader( } case NANOARROW_TYPE_TIME64: { - *out = std::make_unique>(); - return NANOARROW_OK; + switch (pg_type.type_id()) { + case PostgresTypeId::kTime: + *out = std::make_unique>(); + return NANOARROW_OK; + default: + return ErrorCantConvert(error, pg_type, schema_view); + } } case NANOARROW_TYPE_TIMESTAMP: @@ -922,10 +972,11 @@ class PostgresCopyStreamReader { return NANOARROW_OK; } - ArrowErrorCode InferOutputSchema(ArrowError* error) { + ArrowErrorCode InferOutputSchema(const std::string& vendor_name, ArrowError* error) { schema_.reset(); ArrowSchemaInit(schema_.get()); - NANOARROW_RETURN_NOT_OK(root_reader_.InputType().SetSchema(schema_.get())); + NANOARROW_RETURN_NOT_OK( + root_reader_.InputType().SetSchema(schema_.get(), vendor_name)); return NANOARROW_OK; } diff --git a/c/driver/postgresql/copy/writer.h b/c/driver/postgresql/copy/writer.h index 99791ad433..e88ed691cd 100644 --- a/c/driver/postgresql/copy/writer.h +++ b/c/driver/postgresql/copy/writer.h @@ -28,6 +28,7 @@ #include +#include "../connection.h" #include "../postgres_util.h" #include "copy_common.h" @@ -91,13 +92,20 @@ class PostgresCopyFieldWriter { public: virtual ~PostgresCopyFieldWriter() {} - void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; }; + template + static std::unique_ptr Create(struct ArrowArrayView* array_view, Params&&... 
args) { + auto writer = std::make_unique(std::forward(args)...); + writer->Init(array_view); + return writer; + } virtual ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) { return ENOTSUP; } protected: + virtual void Init(struct ArrowArrayView* array_view) { array_view_ = array_view; }; + struct ArrowArrayView* array_view_; std::vector> children_; }; @@ -105,9 +113,7 @@ class PostgresCopyFieldWriter { class PostgresCopyFieldTupleWriter : public PostgresCopyFieldWriter { public: void AppendChild(std::unique_ptr child) { - int64_t child_i = static_cast(children_.size()); children_.push_back(std::move(child)); - children_[child_i]->Init(array_view_->children[child_i]); } ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override { @@ -215,7 +221,7 @@ class PostgresCopyIntervalFieldWriter : public PostgresCopyFieldWriter { template class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter { public: - PostgresCopyNumericFieldWriter(int32_t precision, int32_t scale) + PostgresCopyNumericFieldWriter(int32_t precision, int32_t scale) : precision_{precision}, scale_{scale} {} ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override { @@ -437,6 +443,71 @@ class PostgresCopyBinaryDictFieldWriter : public PostgresCopyFieldWriter { } }; +template +class PostgresCopyListFieldWriter : public PostgresCopyFieldWriter { + public: + explicit PostgresCopyListFieldWriter(uint32_t child_oid, + std::unique_ptr child) + : child_oid_{child_oid}, child_{std::move(child)} {} + + ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override { + if (index >= array_view_->length) { + return ENODATA; + } + + constexpr int32_t ndim = 1; + constexpr int32_t has_null_flags = 0; + + // TODO: the LARGE_LIST should use 64 bit indexes + int32_t start, end; + if constexpr (IsFixedSize) { + start = index * array_view_->layout.child_size_elements; + end = start + array_view_->layout.child_size_elements; + } else { + start = ArrowArrayViewListChildOffset(array_view_, index); + end = ArrowArrayViewListChildOffset(array_view_, index + 1); + } + + const int32_t dim = end - start; + constexpr int32_t lb = 1; + + // for children of a fixed size T we could avoid the use of a temporary buffer + /// and theoretically just write + // + // const int32_t field_size_bytes = + // sizeof(ndim) + sizeof(has_null_flags) + sizeof(child_oid_) + sizeof(dim) * ndim + // + sizeof(lb) * ndim + // + sizeof(int32_t) * dim + T * dim; + // + // directly to our buffer + nanoarrow::UniqueBuffer tmp; + ArrowBufferInit(tmp.get()); + for (auto i = start; i < end; ++i) { + NANOARROW_RETURN_NOT_OK(child_->Write(tmp.get(), i, error)); + } + const int32_t field_size_bytes = sizeof(ndim) + sizeof(has_null_flags) + + sizeof(child_oid_) + sizeof(dim) * ndim + + sizeof(lb) * ndim + tmp->size_bytes; + + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, field_size_bytes, error)); + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, ndim, error)); + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, has_null_flags, error)); + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, child_oid_, error)); + for (int32_t i = 0; i < ndim; ++i) { + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, dim, error)); + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, lb, error)); + } + + ArrowBufferAppend(buffer, tmp->data, tmp->size_bytes); + + return ADBC_STATUS_OK; + } + + private: + const uint32_t child_oid_; + std::unique_ptr child_; +}; + template class PostgresCopyTimestampFieldWriter : public 
PostgresCopyFieldWriter { public: @@ -495,99 +566,141 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter { }; static inline ArrowErrorCode MakeCopyFieldWriter( - struct ArrowSchema* schema, std::unique_ptr* out, - ArrowError* error) { + struct ArrowSchema* schema, struct ArrowArrayView* array_view, + const PostgresTypeResolver& type_resolver, + std::unique_ptr* out, ArrowError* error) { struct ArrowSchemaView schema_view; NANOARROW_RETURN_NOT_OK(ArrowSchemaViewInit(&schema_view, schema, error)); switch (schema_view.type) { case NANOARROW_TYPE_BOOL: - *out = std::make_unique(); + using T = PostgresCopyBooleanFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; case NANOARROW_TYPE_INT8: case NANOARROW_TYPE_INT16: - *out = std::make_unique>(); + case NANOARROW_TYPE_UINT8: { + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_INT32: - *out = std::make_unique>(); + case NANOARROW_TYPE_UINT16: { + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } + case NANOARROW_TYPE_UINT32: case NANOARROW_TYPE_INT64: - *out = std::make_unique>(); + case NANOARROW_TYPE_UINT64: { + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_DATE32: { constexpr int32_t kPostgresDateEpoch = 10957; - *out = std::make_unique< - PostgresCopyNetworkEndianFieldWriter>(); + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; } - case NANOARROW_TYPE_FLOAT: - *out = std::make_unique(); + case NANOARROW_TYPE_TIME64: { + switch (schema_view.time_unit) { + case NANOARROW_TIME_UNIT_MICRO: + using T = PostgresCopyNetworkEndianFieldWriter; + *out = T::Create(array_view); + return NANOARROW_OK; + default: + return ADBC_STATUS_NOT_IMPLEMENTED; + } + } + case NANOARROW_TYPE_HALF_FLOAT: + case NANOARROW_TYPE_FLOAT: { + using T = PostgresCopyFloatFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; - case NANOARROW_TYPE_DOUBLE: - *out = std::make_unique(); + } + case NANOARROW_TYPE_DOUBLE: { + using T = PostgresCopyDoubleFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_DECIMAL128: { + using T = PostgresCopyNumericFieldWriter; const auto precision = schema_view.decimal_precision; const auto scale = schema_view.decimal_scale; - *out = std::make_unique>( - precision, scale); + *out = T::Create(array_view, precision, scale); return NANOARROW_OK; } case NANOARROW_TYPE_DECIMAL256: { + using T = PostgresCopyNumericFieldWriter; const auto precision = schema_view.decimal_precision; const auto scale = schema_view.decimal_scale; - *out = std::make_unique>( - precision, scale); + *out = T::Create(array_view, precision, scale); return NANOARROW_OK; } case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: - *out = std::make_unique(); + case NANOARROW_TYPE_STRING_VIEW: { + using T = PostgresCopyBinaryFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_TIMESTAMP: { switch (schema_view.time_unit) { - case NANOARROW_TIME_UNIT_NANO: - *out = std::make_unique< - PostgresCopyTimestampFieldWriter>(); + case NANOARROW_TIME_UNIT_NANO: { + using T = PostgresCopyTimestampFieldWriter; + *out = T::Create(array_view); break; - case 
NANOARROW_TIME_UNIT_MILLI: - *out = std::make_unique< - PostgresCopyTimestampFieldWriter>(); + } + case NANOARROW_TIME_UNIT_MILLI: { + using T = PostgresCopyTimestampFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_MICRO: - *out = std::make_unique< - PostgresCopyTimestampFieldWriter>(); + } + case NANOARROW_TIME_UNIT_MICRO: { + using T = PostgresCopyTimestampFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_SECOND: - *out = std::make_unique< - PostgresCopyTimestampFieldWriter>(); + } + case NANOARROW_TIME_UNIT_SECOND: { + using T = PostgresCopyTimestampFieldWriter; + *out = T::Create(array_view); break; + } } return NANOARROW_OK; } - case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - *out = std::make_unique(); + case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { + using T = PostgresCopyIntervalFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } case NANOARROW_TYPE_DURATION: { switch (schema_view.time_unit) { - case NANOARROW_TIME_UNIT_SECOND: - *out = std::make_unique< - PostgresCopyDurationFieldWriter>(); + case NANOARROW_TIME_UNIT_SECOND: { + using T = PostgresCopyDurationFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_MILLI: - *out = std::make_unique< - PostgresCopyDurationFieldWriter>(); + } + case NANOARROW_TIME_UNIT_MILLI: { + using T = PostgresCopyDurationFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_MICRO: - *out = std::make_unique< - PostgresCopyDurationFieldWriter>(); - + } + case NANOARROW_TIME_UNIT_MICRO: { + using T = PostgresCopyDurationFieldWriter; + *out = T::Create(array_view); break; - case NANOARROW_TIME_UNIT_NANO: - *out = std::make_unique< - PostgresCopyDurationFieldWriter>(); + } + case NANOARROW_TIME_UNIT_NANO: { + using T = PostgresCopyDurationFieldWriter; + *out = T::Create(array_view); break; + } } return NANOARROW_OK; } @@ -599,12 +712,41 @@ static inline ArrowErrorCode MakeCopyFieldWriter( case NANOARROW_TYPE_BINARY: case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_BINARY: - case NANOARROW_TYPE_LARGE_STRING: - *out = std::make_unique(); + case NANOARROW_TYPE_LARGE_STRING: { + using T = PostgresCopyBinaryDictFieldWriter; + *out = T::Create(array_view); return NANOARROW_OK; + } default: break; } + break; + } + case NANOARROW_TYPE_LIST: + case NANOARROW_TYPE_LARGE_LIST: + case NANOARROW_TYPE_FIXED_SIZE_LIST: { + // For now our implementation only supports primitive children types + // See PostgresCopyListFieldWriter::Write for limtiations + struct ArrowSchemaView child_schema_view; + NANOARROW_RETURN_NOT_OK( + ArrowSchemaViewInit(&child_schema_view, schema->children[0], error)); + PostgresType child_type; + NANOARROW_RETURN_NOT_OK(PostgresType::FromSchema(type_resolver, schema->children[0], + &child_type, error)); + + std::unique_ptr child_writer; + NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(schema->children[0], + array_view->children[0], type_resolver, + &child_writer, error)); + + if (schema_view.type == NANOARROW_TYPE_FIXED_SIZE_LIST) { + using T = PostgresCopyListFieldWriter; + *out = T::Create(array_view, child_type.oid(), std::move(child_writer)); + } else { + using T = PostgresCopyListFieldWriter; + *out = T::Create(array_view, child_type.oid(), std::move(child_writer)); + } + return NANOARROW_OK; } default: break; @@ -620,7 +762,8 @@ class PostgresCopyStreamWriter { schema_ = schema; NANOARROW_RETURN_NOT_OK( ArrowArrayViewInitFromSchema(&array_view_.value, schema, nullptr)); - root_writer_.Init(&array_view_.value); + 
root_writer_ = PostgresCopyFieldTupleWriter::Create( + &array_view_.value); ArrowBufferInit(&buffer_.value); return NANOARROW_OK; } @@ -646,21 +789,23 @@ class PostgresCopyStreamWriter { } ArrowErrorCode WriteRecord(ArrowError* error) { - NANOARROW_RETURN_NOT_OK(root_writer_.Write(&buffer_.value, records_written_, error)); + NANOARROW_RETURN_NOT_OK(root_writer_->Write(&buffer_.value, records_written_, error)); records_written_++; return NANOARROW_OK; } - ArrowErrorCode InitFieldWriters(ArrowError* error) { + ArrowErrorCode InitFieldWriters(const PostgresTypeResolver& type_resolver, + ArrowError* error) { if (schema_->release == nullptr) { return EINVAL; } for (int64_t i = 0; i < schema_->n_children; i++) { std::unique_ptr child_writer; - NANOARROW_RETURN_NOT_OK( - MakeCopyFieldWriter(schema_->children[i], &child_writer, error)); - root_writer_.AppendChild(std::move(child_writer)); + NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(schema_->children[i], + array_view_->children[i], type_resolver, + &child_writer, error)); + root_writer_->AppendChild(std::move(child_writer)); } return NANOARROW_OK; @@ -674,7 +819,7 @@ class PostgresCopyStreamWriter { } private: - PostgresCopyFieldTupleWriter root_writer_; + std::unique_ptr root_writer_; struct ArrowSchema* schema_; Handle array_view_; Handle buffer_; diff --git a/c/driver/postgresql/database.cc b/c/driver/postgresql/database.cc index 2837e82338..1bd7444884 100644 --- a/c/driver/postgresql/database.cc +++ b/c/driver/postgresql/database.cc @@ -17,17 +17,20 @@ #include "database.h" +#include +#include #include #include #include #include #include -#include +#include #include #include #include "driver/common/utils.h" +#include "result_helper.h" namespace adbcpq { @@ -54,8 +57,19 @@ AdbcStatusCode PostgresDatabase::GetOptionDouble(const char* option, double* val } AdbcStatusCode PostgresDatabase::Init(struct AdbcError* error) { - // Connect to validate the parameters. - return RebuildTypeResolver(error); + // Connect to initialize the version information and build the type table + PGconn* conn = nullptr; + RAISE_ADBC(Connect(&conn, error)); + + Status status = InitVersions(conn); + if (!status.ok()) { + RAISE_ADBC(Disconnect(&conn, nullptr)); + return status.ToAdbc(error); + } + + status = RebuildTypeResolver(conn); + RAISE_ADBC(Disconnect(&conn, nullptr)); + return status.ToAdbc(error); } AdbcStatusCode PostgresDatabase::Release(struct AdbcError* error) { @@ -123,20 +137,87 @@ AdbcStatusCode PostgresDatabase::Disconnect(PGconn** conn, struct AdbcError* err return ADBC_STATUS_OK; } -// Helpers for building the type resolver from queries -static inline int32_t InsertPgAttributeResult( - PGresult* result, const std::shared_ptr& resolver); +namespace { + +// Parse an individual version in the form of "xxx.xxx.xxx". +// If the version components aren't numeric, they will be zero. 
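To make that behaviour concrete, here is a self-contained sketch of the same parsing idea with a couple of example inputs (the sample strings are illustrative; the real helper follows below):

#include <array>
#include <charconv>
#include <string_view>

static std::array<int, 3> ParseVersionSketch(std::string_view v) {
  std::array<int, 3> out{};
  size_t begin = 0;
  for (size_t i = 0; i < out.size() && begin < v.size(); i++) {
    size_t end = v.find_first_of(".-", begin);
    if (end == v.npos) end = v.size();
    // Non-numeric components are simply left at zero.
    std::from_chars(v.data() + begin, v.data() + end, out[i]);
    begin = end + 1;
  }
  return out;
}

// ParseVersionSketch("16.2")      -> {16, 2, 0}
// ParseVersionSketch("1.0.77467") -> {1, 0, 77467}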
+std::array ParseVersion(std::string_view version) { + std::array out{}; + size_t component = 0; + size_t component_begin = 0; + size_t component_end = 0; + + // While there are remaining version components and we haven't reached the end of the + // string + while (component_begin < version.size() && component < out.size()) { + // Find the next character that marks a version component separation or the end of the + // string + component_end = version.find_first_of(".-", component_begin); + if (component_end == version.npos) { + component_end = version.size(); + } -static inline int32_t InsertPgTypeResult( - PGresult* result, const std::shared_ptr& resolver); + // Try to parse the component as an integer (assigning zero if this fails) + int value = 0; + std::from_chars(version.data() + component_begin, version.data() + component_end, + value); + out[component] = value; -AdbcStatusCode PostgresDatabase::RebuildTypeResolver(struct AdbcError* error) { - PGconn* conn = nullptr; - AdbcStatusCode final_status = Connect(&conn, error); - if (final_status != ADBC_STATUS_OK) { - return final_status; + // Move on to the next component + component_begin = component_end + 1; + component_end = component_begin; + component++; + } + + return out; +} + +// Parse the PostgreSQL version() string that looks like: +// PostgreSQL 8.0.2 on i686-pc-linux-gnu, compiled by GCC gcc (GCC) 3.4.2 20041017 (Red +// Hat 3.4.2-6.fc3), Redshift 1.0.77467 +std::array ParsePrefixedVersion(std::string_view version_info, + std::string_view prefix) { + size_t pos = version_info.find(prefix); + if (pos == version_info.npos) { + return {0, 0, 0}; } + // Skip the prefix and any leading whitespace + pos = version_info.find_first_not_of(' ', pos + prefix.size()); + if (pos == version_info.npos) { + return {0, 0, 0}; + } + + return ParseVersion(version_info.substr(pos)); +} + +} // namespace + +Status PostgresDatabase::InitVersions(PGconn* conn) { + PqResultHelper helper(conn, "SELECT version();"); + UNWRAP_STATUS(helper.Execute()); + if (helper.NumRows() != 1 || helper.NumColumns() != 1) { + return Status::Internal("Expected 1 row and 1 column for SELECT version(); but got ", + helper.NumRows(), "/", helper.NumColumns()); + } + + std::string_view version_info = helper.Row(0)[0].value(); + postgres_server_version_ = ParsePrefixedVersion(version_info, "PostgreSQL"); + redshift_server_version_ = ParsePrefixedVersion(version_info, "Redshift"); + + return Status::Ok(); +} + +// Helpers for building the type resolver from queries +static std::string BuildPgTypeQuery(bool has_typarray); + +static Status InsertPgAttributeResult( + const PqResultHelper& result, const std::shared_ptr& resolver); + +static Status InsertPgTypeResult(const PqResultHelper& result, + const std::shared_ptr& resolver); + +Status PostgresDatabase::RebuildTypeResolver(PGconn* conn) { // We need a few queries to build the resolver. The current strategy might // fail for some recursive definitions (e.g., arrays of records of arrays). // First, one on the pg_attribute table to resolve column names/oids for @@ -156,147 +237,131 @@ ORDER BY // recursive definitions (e.g., record types with array column). This currently won't // handle range types because those rows don't have child OID information. Arrays types // are inserted after a successful insert of the element type. 
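For reference, the two statements assembled by BuildPgTypeQuery (defined further down in this patch) look roughly as follows; the typarray column and the array_recv filter are dropped for Redshift, apparently because Redshift (based on PostgreSQL 8.0.x) does not expose pg_type.typarray (formatting and constant names here are illustrative):

// Against PostgreSQL (array types are filtered out here and re-inserted
// from their element type after a successful insert):
static const char* kPgTypeQueryPostgres =
    "SELECT oid, typname, typreceive, typbasetype, typrelid, typarray"
    " FROM pg_catalog.pg_type"
    " WHERE (typreceive != 0 OR typsend != 0) AND typtype != 'r'"
    " AND typreceive::TEXT != 'array_recv'";

// Against Redshift:
static const char* kPgTypeQueryRedshift =
    "SELECT oid, typname, typreceive, typbasetype, typrelid"
    " FROM pg_catalog.pg_type"
    " WHERE (typreceive != 0 OR typsend != 0) AND typtype != 'r'";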
- const std::string kTypeQuery = R"( -SELECT - oid, - typname, - typreceive, - typbasetype, - typarray, - typrelid -FROM - pg_catalog.pg_type -WHERE - (typreceive != 0 OR typname = 'aclitem') AND typtype != 'r' AND typreceive::TEXT != 'array_recv' -ORDER BY - oid -)"; + std::string type_query = + BuildPgTypeQuery(/*has_typarray*/ redshift_server_version_[0] == 0); // Create a new type resolver (this instance's type_resolver_ member // will be updated at the end if this succeeds). auto resolver = std::make_shared(); // Insert record type definitions (this includes table schemas) - PGresult* result = PQexec(conn, kColumnsQuery.c_str()); - ExecStatusType pq_status = PQresultStatus(result); - if (pq_status == PGRES_TUPLES_OK) { - InsertPgAttributeResult(result, resolver); - } else { - SetError(error, "%s%s", - "[libpq] Failed to build type mapping table: ", PQerrorMessage(conn)); - final_status = ADBC_STATUS_IO; - } - - PQclear(result); + PqResultHelper columns(conn, kColumnsQuery.c_str()); + UNWRAP_STATUS(columns.Execute()); + UNWRAP_STATUS(InsertPgAttributeResult(columns, resolver)); // Attempt filling the resolver a few times to handle recursive definitions. int32_t max_attempts = 3; + PqResultHelper types(conn, type_query); for (int32_t i = 0; i < max_attempts; i++) { - result = PQexec(conn, kTypeQuery.c_str()); - ExecStatusType pq_status = PQresultStatus(result); - if (pq_status == PGRES_TUPLES_OK) { - InsertPgTypeResult(result, resolver); - } else { - SetError(error, "%s%s", - "[libpq] Failed to build type mapping table: ", PQerrorMessage(conn)); - final_status = ADBC_STATUS_IO; - } - - PQclear(result); - if (final_status != ADBC_STATUS_OK) { - break; - } + UNWRAP_STATUS(types.Execute()); + UNWRAP_STATUS(InsertPgTypeResult(types, resolver)); } - // Disconnect since PostgreSQL connections can be heavy. 
- { - AdbcStatusCode status = Disconnect(&conn, error); - if (status != ADBC_STATUS_OK) final_status = status; - } + type_resolver_ = std::move(resolver); + return Status::Ok(); +} - if (final_status == ADBC_STATUS_OK) { - type_resolver_ = std::move(resolver); +static std::string BuildPgTypeQuery(bool has_typarray) { + std::string maybe_typarray_col; + std::string maybe_array_recv_filter; + if (has_typarray) { + maybe_typarray_col = ", typarray"; + maybe_array_recv_filter = "AND typreceive::TEXT != 'array_recv'"; } - return final_status; + return std::string() + "SELECT oid, typname, typreceive, typbasetype, typrelid" + + maybe_typarray_col + " FROM pg_catalog.pg_type " + + " WHERE (typreceive != 0 OR typsend != 0) AND typtype != 'r' " + + maybe_array_recv_filter; } -static inline int32_t InsertPgAttributeResult( - PGresult* result, const std::shared_ptr& resolver) { - int num_rows = PQntuples(result); +static Status InsertPgAttributeResult( + const PqResultHelper& result, const std::shared_ptr& resolver) { + int num_rows = result.NumRows(); std::vector> columns; - uint32_t current_type_oid = 0; - int32_t n_added = 0; + int64_t current_type_oid = 0; + + if (result.NumColumns() != 3) { + return Status::Internal( + "Expected 3 columns from type resolver pg_attribute query but got ", + result.NumColumns()); + } for (int row = 0; row < num_rows; row++) { - const uint32_t type_oid = static_cast( - std::strtol(PQgetvalue(result, row, 0), /*str_end=*/nullptr, /*base=*/10)); - const char* col_name = PQgetvalue(result, row, 1); - const uint32_t col_oid = static_cast( - std::strtol(PQgetvalue(result, row, 2), /*str_end=*/nullptr, /*base=*/10)); + PqResultRow item = result.Row(row); + UNWRAP_RESULT(int64_t type_oid, item[0].ParseInteger()); + std::string_view col_name = item[1].value(); + UNWRAP_RESULT(int64_t col_oid, item[2].ParseInteger()); if (type_oid != current_type_oid && !columns.empty()) { - resolver->InsertClass(current_type_oid, columns); + resolver->InsertClass(static_cast(current_type_oid), columns); columns.clear(); current_type_oid = type_oid; - n_added++; } - columns.push_back({col_name, col_oid}); + columns.push_back({std::string(col_name), static_cast(col_oid)}); } if (!columns.empty()) { - resolver->InsertClass(current_type_oid, columns); - n_added++; + resolver->InsertClass(static_cast(current_type_oid), columns); } - return n_added; + return Status::Ok(); } -static inline int32_t InsertPgTypeResult( - PGresult* result, const std::shared_ptr& resolver) { - int num_rows = PQntuples(result); - PostgresTypeResolver::Item item; - int32_t n_added = 0; +static Status InsertPgTypeResult(const PqResultHelper& result, + const std::shared_ptr& resolver) { + if (result.NumColumns() != 5 && result.NumColumns() != 6) { + return Status::Internal( + "Expected 5 or 6 columns from type resolver pg_type query but got ", + result.NumColumns()); + } + + int num_rows = result.NumRows(); + int num_cols = result.NumColumns(); + PostgresTypeResolver::Item type_item; for (int row = 0; row < num_rows; row++) { - const uint32_t oid = static_cast( - std::strtol(PQgetvalue(result, row, 0), /*str_end=*/nullptr, /*base=*/10)); - const char* typname = PQgetvalue(result, row, 1); - const char* typreceive = PQgetvalue(result, row, 2); - const uint32_t typbasetype = static_cast( - std::strtol(PQgetvalue(result, row, 3), /*str_end=*/nullptr, /*base=*/10)); - const uint32_t typarray = static_cast( - std::strtol(PQgetvalue(result, row, 4), /*str_end=*/nullptr, /*base=*/10)); - const uint32_t typrelid = static_cast( - 
std::strtol(PQgetvalue(result, row, 5), /*str_end=*/nullptr, /*base=*/10)); + PqResultRow item = result.Row(row); + UNWRAP_RESULT(int64_t oid, item[0].ParseInteger()); + const char* typname = item[1].data; + const char* typreceive = item[2].data; + UNWRAP_RESULT(int64_t typbasetype, item[3].ParseInteger()); + UNWRAP_RESULT(int64_t typrelid, item[4].ParseInteger()); + + int64_t typarray; + if (num_cols == 6) { + UNWRAP_RESULT(typarray, item[5].ParseInteger()); + } else { + typarray = 0; + } // Special case the aclitem because it shows up in a bunch of internal tables if (strcmp(typname, "aclitem") == 0) { typreceive = "aclitem_recv"; } - item.oid = oid; - item.typname = typname; - item.typreceive = typreceive; - item.class_oid = typrelid; - item.base_oid = typbasetype; + type_item.oid = static_cast(oid); + type_item.typname = typname; + type_item.typreceive = typreceive; + type_item.class_oid = static_cast(typrelid); + type_item.base_oid = static_cast(typbasetype); - int result = resolver->Insert(item, nullptr); + int insert_result = resolver->Insert(type_item, nullptr); // If there's an array type and the insert succeeded, add that now too - if (result == NANOARROW_OK && typarray != 0) { + if (insert_result == NANOARROW_OK && typarray != 0) { std::string array_typname = "_" + std::string(typname); - item.oid = typarray; - item.typname = array_typname.c_str(); - item.typreceive = "array_recv"; - item.child_oid = oid; + type_item.oid = static_cast(typarray); + type_item.typname = array_typname.c_str(); + type_item.typreceive = "array_recv"; + type_item.child_oid = static_cast(oid); - resolver->Insert(item, nullptr); + resolver->Insert(type_item, nullptr); } } - return n_added; + return Status::Ok(); } } // namespace adbcpq diff --git a/c/driver/postgresql/database.h b/c/driver/postgresql/database.h index 6c3da58daa..e0a00267e3 100644 --- a/c/driver/postgresql/database.h +++ b/c/driver/postgresql/database.h @@ -17,16 +17,20 @@ #pragma once +#include #include #include #include -#include +#include #include +#include "driver/framework/status.h" #include "postgres_type.h" namespace adbcpq { +using adbc::driver::Status; + class PostgresDatabase { public: PostgresDatabase(); @@ -58,12 +62,29 @@ class PostgresDatabase { return type_resolver_; } - AdbcStatusCode RebuildTypeResolver(struct AdbcError* error); + Status InitVersions(PGconn* conn); + Status RebuildTypeResolver(PGconn* conn); + std::string_view VendorName() { + if (redshift_server_version_[0] != 0) { + return "Redshift"; + } else { + return "PostgreSQL"; + } + } + const std::array& VendorVersion() { + if (redshift_server_version_[0] != 0) { + return redshift_server_version_; + } else { + return postgres_server_version_; + } + } private: int32_t open_connections_; std::string uri_; std::shared_ptr type_resolver_; + std::array postgres_server_version_{}; + std::array redshift_server_version_{}; }; } // namespace adbcpq diff --git a/c/driver/postgresql/error.cc b/c/driver/postgresql/error.cc index 276aadc1ce..173868baf5 100644 --- a/c/driver/postgresql/error.cc +++ b/c/driver/postgresql/error.cc @@ -17,8 +17,8 @@ #include "error.h" -#include #include +#include #include #include #include @@ -29,72 +29,27 @@ namespace adbcpq { -namespace { -struct DetailField { - int code; - std::string key; -}; - -static const std::vector kDetailFields = { - {PG_DIAG_COLUMN_NAME, "PG_DIAG_COLUMN_NAME"}, - {PG_DIAG_CONTEXT, "PG_DIAG_CONTEXT"}, - {PG_DIAG_CONSTRAINT_NAME, "PG_DIAG_CONSTRAINT_NAME"}, - {PG_DIAG_DATATYPE_NAME, "PG_DIAG_DATATYPE_NAME"}, - 
{PG_DIAG_INTERNAL_POSITION, "PG_DIAG_INTERNAL_POSITION"}, - {PG_DIAG_INTERNAL_QUERY, "PG_DIAG_INTERNAL_QUERY"}, - {PG_DIAG_MESSAGE_PRIMARY, "PG_DIAG_MESSAGE_PRIMARY"}, - {PG_DIAG_MESSAGE_DETAIL, "PG_DIAG_MESSAGE_DETAIL"}, - {PG_DIAG_MESSAGE_HINT, "PG_DIAG_MESSAGE_HINT"}, - {PG_DIAG_SEVERITY_NONLOCALIZED, "PG_DIAG_SEVERITY_NONLOCALIZED"}, - {PG_DIAG_SQLSTATE, "PG_DIAG_SQLSTATE"}, - {PG_DIAG_STATEMENT_POSITION, "PG_DIAG_STATEMENT_POSITION"}, - {PG_DIAG_SCHEMA_NAME, "PG_DIAG_SCHEMA_NAME"}, - {PG_DIAG_TABLE_NAME, "PG_DIAG_TABLE_NAME"}, -}; -} // namespace - AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, ...) { + if (error && error->release) { + // TODO: combine the errors if possible + error->release(error); + } + va_list args; va_start(args, format); - SetErrorVariadic(error, format, args); + std::string message; + message.resize(1024); + int chars_needed = vsnprintf(message.data(), message.size(), format, args); va_end(args); - AdbcStatusCode code = ADBC_STATUS_IO; - - const char* sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); - if (sqlstate) { - // https://www.postgresql.org/docs/current/errcodes-appendix.html - // This can be extended in the future - if (std::strcmp(sqlstate, "57014") == 0) { - code = ADBC_STATUS_CANCELLED; - } else if (std::strcmp(sqlstate, "42P01") == 0 || - std::strcmp(sqlstate, "42602") == 0) { - code = ADBC_STATUS_NOT_FOUND; - } else if (std::strncmp(sqlstate, "42", 0) == 0) { - // Class 42 — Syntax Error or Access Rule Violation - code = ADBC_STATUS_INVALID_ARGUMENT; - } - - static_assert(sizeof(error->sqlstate) == 5, ""); - // N.B. strncpy generates warnings when used for this purpose - int i = 0; - for (; sqlstate[i] != '\0' && i < 5; i++) { - error->sqlstate[i] = sqlstate[i]; - } - for (; i < 5; i++) { - error->sqlstate[i] = '\0'; - } + if (chars_needed > 0) { + message.resize(chars_needed); + } else { + message.resize(0); } - for (const auto& field : kDetailFields) { - const char* value = PQresultErrorField(result, field.code); - if (value) { - AppendErrorDetail(error, field.key.c_str(), reinterpret_cast(value), - std::strlen(value)); - } - } - return code; + return MakeStatus(result, "{}", message).ToAdbc(error); } } // namespace adbcpq diff --git a/c/driver/postgresql/error.h b/c/driver/postgresql/error.h index 75c52b46c3..f24d41754b 100644 --- a/c/driver/postgresql/error.h +++ b/c/driver/postgresql/error.h @@ -19,11 +19,42 @@ #pragma once -#include +#include +#include + +#include #include +#include + +#include "driver/framework/status.h" + +using adbc::driver::Status; + namespace adbcpq { +struct DetailField { + int code; + std::string key; +}; + +static const std::vector kDetailFields = { + {PG_DIAG_COLUMN_NAME, "PG_DIAG_COLUMN_NAME"}, + {PG_DIAG_CONTEXT, "PG_DIAG_CONTEXT"}, + {PG_DIAG_CONSTRAINT_NAME, "PG_DIAG_CONSTRAINT_NAME"}, + {PG_DIAG_DATATYPE_NAME, "PG_DIAG_DATATYPE_NAME"}, + {PG_DIAG_INTERNAL_POSITION, "PG_DIAG_INTERNAL_POSITION"}, + {PG_DIAG_INTERNAL_QUERY, "PG_DIAG_INTERNAL_QUERY"}, + {PG_DIAG_MESSAGE_PRIMARY, "PG_DIAG_MESSAGE_PRIMARY"}, + {PG_DIAG_MESSAGE_DETAIL, "PG_DIAG_MESSAGE_DETAIL"}, + {PG_DIAG_MESSAGE_HINT, "PG_DIAG_MESSAGE_HINT"}, + {PG_DIAG_SEVERITY_NONLOCALIZED, "PG_DIAG_SEVERITY_NONLOCALIZED"}, + {PG_DIAG_SQLSTATE, "PG_DIAG_SQLSTATE"}, + {PG_DIAG_STATEMENT_POSITION, "PG_DIAG_STATEMENT_POSITION"}, + {PG_DIAG_SCHEMA_NAME, "PG_DIAG_SCHEMA_NAME"}, + {PG_DIAG_TABLE_NAME, "PG_DIAG_TABLE_NAME"}, +}; + // The printf checking attribute doesn't work properly on gcc 4.8 // and results in spurious 
compiler warnings #if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) @@ -33,10 +64,50 @@ namespace adbcpq { #endif /// \brief Set an error based on a PGresult, inferring the proper ADBC status -/// code from the PGresult. +/// code from the PGresult. Deprecated and is currently a thin wrapper around +/// MakeStatus() below. AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, ...) ADBC_CHECK_PRINTF_ATTRIBUTE(3, 4); #undef ADBC_CHECK_PRINTF_ATTRIBUTE +template +Status MakeStatus(PGresult* result, const char* format_string, Args&&... args) { + auto message = ::fmt::vformat(format_string, ::fmt::make_format_args(args...)); + + AdbcStatusCode code = ADBC_STATUS_IO; + char sqlstate_out[5]; + std::memset(sqlstate_out, 0, sizeof(sqlstate_out)); + + if (result == nullptr) { + return Status(code, message); + } + + const char* sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); + if (sqlstate) { + // https://www.postgresql.org/docs/current/errcodes-appendix.html + // This can be extended in the future + if (std::strcmp(sqlstate, "57014") == 0) { + code = ADBC_STATUS_CANCELLED; + } else if (std::strcmp(sqlstate, "42P01") == 0 || + std::strcmp(sqlstate, "42602") == 0) { + code = ADBC_STATUS_NOT_FOUND; + } else if (std::strncmp(sqlstate, "42", 0) == 0) { + // Class 42 — Syntax Error or Access Rule Violation + code = ADBC_STATUS_INVALID_ARGUMENT; + } + } + + Status status(code, message); + status.SetSqlState(sqlstate); + for (const auto& field : kDetailFields) { + const char* value = PQresultErrorField(result, field.code); + if (value) { + status.AddDetail(field.key, value); + } + } + + return status; +} + } // namespace adbcpq diff --git a/c/driver/postgresql/meson.build b/c/driver/postgresql/meson.build new file mode 100644 index 0000000000..ac075417f5 --- /dev/null +++ b/c/driver/postgresql/meson.build @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
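For context on the error path above: the MakeStatus() template added to error.h formats its message with fmt, maps the libpq SQLSTATE onto an AdbcStatusCode, and attaches each PG_DIAG_* field from kDetailFields as an error detail. A minimal sketch of how driver code might wrap a failed PGresult with it, assuming the usual libpq and error.h includes; ExecOrStatus is a hypothetical helper, not part of this patch:

Status ExecOrStatus(PGconn* conn, const std::string& sql) {
  PGresult* result = PQexec(conn, sql.c_str());
  ExecStatusType pq_status = PQresultStatus(result);
  if (pq_status != PGRES_TUPLES_OK && pq_status != PGRES_COMMAND_OK) {
    // MakeStatus copies the SQLSTATE and PG_DIAG_* fields out of the result,
    // so the PGresult can be cleared immediately afterwards.
    Status status = MakeStatus(result, "[libpq] Query failed: {}", PQerrorMessage(conn));
    PQclear(result);
    return status;
  }
  PQclear(result);
  return Status::Ok();
}

The same pattern appears in PqResultHelper::Prepare() and Execute() later in this patch.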
+ +libpq_dep = dependency('libpq') + +adbc_postgres_driver_lib = library( + 'adbc_driver_postgresql', + sources: [ + 'connection.cc', + 'error.cc', + 'database.cc', + 'postgresql.cc', + 'result_helper.cc', + 'result_reader.cc', + 'statement.cc', + ], + include_directories: [include_dir, c_dir], + link_with: [adbc_common_lib, adbc_framework_lib], + dependencies: [nanoarrow_dep, fmt_dep, libpq_dep], +) + +pkg.generate( + name: 'Apache Arrow Database Connectivity (ADBC) PostgreSQL driver', + description: 'The ADBC PostgreSQL driver provides an ADBC driver for PostgreSQL.', + url: 'https://github.com/apache/arrow-adbc', + libraries: [adbc_postgres_driver_lib], + filebase: 'adbc-driver-postgresql', +) + +if get_option('tests') + postgres_tests = { + 'driver-postgresql': { + 'src_name': 'driver_postgresql', + 'sources': [ + 'postgres_type_test.cc', + 'postgresql_test.cc', + ] + }, + 'driver-postgresql-copy': { + 'src_name': 'driver_postgresql_copy', + 'sources': [ + 'copy/postgres_copy_reader_test.cc', + 'copy/postgres_copy_writer_test.cc', + ] + }, + } + + foreach name, conf : postgres_tests + exc = executable( + 'adbc-' + name + '-test', + sources: conf['sources'], + include_directories: [include_dir, driver_dir, c_dir], + link_with: [ + adbc_common_lib, + adbc_postgres_driver_lib, + ], + dependencies: [libpq_dep, adbc_validation_dep], + ) + test('adbc-' + name, exc) + endforeach +endif diff --git a/c/driver/postgresql/postgres_type.h b/c/driver/postgresql/postgres_type.h index c7cc55745a..d2a5356293 100644 --- a/c/driver/postgresql/postgres_type.h +++ b/c/driver/postgresql/postgres_type.h @@ -111,7 +111,11 @@ enum class PostgresTypeId { kXid8, kXid, kXml, - kUserDefined + kUserDefined, + // This is not an actual type, but there are cases where all we have is an Oid + // that was not inserted into the type resolver. We can't use "unknown" or "opaque" + // or "void" because those names show up in actual pg_type tables. + kUnnamedArrowOpaque }; // Returns the receive function name as defined in the typrecieve column @@ -139,6 +143,11 @@ class PostgresType { PostgresType() : PostgresType(PostgresTypeId::kUninitialized) {} + static PostgresType Unnamed(uint32_t oid) { + return PostgresType(PostgresTypeId::kUnnamedArrowOpaque) + .WithPgTypeInfo(oid, "unnamed"); + } + void AppendChild(const std::string& field_name, const PostgresType& type) { PostgresType child(type); children_.push_back(child.WithFieldName(field_name)); @@ -184,6 +193,19 @@ class PostgresType { int64_t n_children() const { return static_cast(children_.size()); } const PostgresType& child(int64_t i) const { return children_[i]; } + // The name used to communicate this type in a CREATE TABLE statement. + // These are not necessarily the most idiomatic names to use but PostgreSQL + // will accept typname() according to the "aliases" column in + // https://www.postgresql.org/docs/current/datatype.html + const std::string sql_type_name() const { + switch (type_id_) { + case PostgresTypeId::kArray: + return children_[0].sql_type_name() + " ARRAY"; + default: + return typname_; + } + } + // Sets appropriate fields of an ArrowSchema that has been initialized using // ArrowSchemaInit. This is a recursive operation (i.e., nested types will // initialize and set the appropriate number of children). Returns NANOARROW_OK @@ -191,7 +213,8 @@ class PostgresType { // do not have a corresponding Arrow type are returned as Binary with field // metadata ADBC:posgresql:typname. 
These types can be represented as their // binary COPY representation in the output. - ArrowErrorCode SetSchema(ArrowSchema* schema) const { + ArrowErrorCode SetSchema(ArrowSchema* schema, + const std::string& vendor_name = "PostgreSQL") const { switch (type_id_) { // ---- Primitive types -------------------- case PostgresTypeId::kBool: @@ -222,7 +245,7 @@ class PostgresType { // ---- Numeric/Decimal------------------- case PostgresTypeId::kNumeric: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING)); - NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema)); + NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema, vendor_name)); break; @@ -233,6 +256,8 @@ class PostgresType { case PostgresTypeId::kText: case PostgresTypeId::kName: case PostgresTypeId::kEnum: + case PostgresTypeId::kJson: + case PostgresTypeId::kJsonb: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING)); break; case PostgresTypeId::kBytea: @@ -275,13 +300,14 @@ class PostgresType { case PostgresTypeId::kRecord: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children())); for (int64_t i = 0; i < n_children(); i++) { - NANOARROW_RETURN_NOT_OK(children_[i].SetSchema(schema->children[i])); + NANOARROW_RETURN_NOT_OK( + children_[i].SetSchema(schema->children[i], vendor_name)); } break; case PostgresTypeId::kArray: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_LIST)); - NANOARROW_RETURN_NOT_OK(children_[0].SetSchema(schema->children[0])); + NANOARROW_RETURN_NOT_OK(children_[0].SetSchema(schema->children[0], vendor_name)); break; case PostgresTypeId::kUserDefined: @@ -290,7 +316,7 @@ class PostgresType { // can still return the bytes postgres gives us and attach the type name as // metadata NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_BINARY)); - NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema)); + NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema, vendor_name)); break; } @@ -310,8 +336,12 @@ class PostgresType { std::vector children_; static constexpr const char* kPostgresTypeKey = "ADBC:postgresql:typname"; + static constexpr const char* kExtensionName = "ARROW:extension:name"; + static constexpr const char* kOpaqueExtensionName = "arrow.opaque"; + static constexpr const char* kExtensionMetadata = "ARROW:extension:metadata"; - ArrowErrorCode AddPostgresTypeMetadata(ArrowSchema* schema) const { + ArrowErrorCode AddPostgresTypeMetadata(ArrowSchema* schema, + const std::string& vendor_name) const { // the typname_ may not always be set: an instance of this class can be // created with just the type id. That's why there is this here fallback to // resolve the type name of built-in types. 
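The hunk below writes both the legacy ADBC:postgresql:typname key and the arrow.opaque extension metadata onto the schema. A sketch of how a consumer might detect that fallback and recover the vendor type JSON using only nanoarrow calls; GetOpaqueTypeJson is a hypothetical helper, not part of this patch:

static bool GetOpaqueTypeJson(const struct ArrowSchema* column, std::string* json_out) {
  // Columns the driver could not map are plain binary/string columns whose
  // ARROW:extension:name metadata is "arrow.opaque".
  struct ArrowStringView value = {nullptr, 0};
  ArrowMetadataGetValue(column->metadata, ArrowCharView("ARROW:extension:name"), &value);
  if (value.data == nullptr ||
      std::string_view(value.data, value.size_bytes) != "arrow.opaque") {
    return false;
  }
  // e.g. {"type_name": "numeric", "vendor_name": "PostgreSQL"}
  value = {nullptr, 0};
  ArrowMetadataGetValue(column->metadata, ArrowCharView("ARROW:extension:metadata"),
                        &value);
  if (value.data == nullptr) return false;
  json_out->assign(value.data, value.size_bytes);
  return true;
}

The PostgresTypeSetSchema test further below checks exactly these keys and the numeric example payload.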
@@ -320,8 +350,25 @@ class PostgresType { nanoarrow::UniqueBuffer buffer; ArrowMetadataBuilderInit(buffer.get(), nullptr); + // TODO(lidavidm): we have deprecated this in favor of arrow.opaque, + // remove once we feel enough time has passed NANOARROW_RETURN_NOT_OK(ArrowMetadataBuilderAppend( buffer.get(), ArrowCharView(kPostgresTypeKey), ArrowCharView(typname))); + + // Add the Opaque extension type metadata + std::string metadata = R"({"type_name": ")"; + metadata += typname; + metadata += R"(", "vendor_name": ")" + vendor_name + R"("})"; + NANOARROW_RETURN_NOT_OK( + ArrowMetadataBuilderAppend(buffer.get(), ArrowCharView(kExtensionName), + ArrowCharView(kOpaqueExtensionName))); + NANOARROW_RETURN_NOT_OK( + ArrowMetadataBuilderAppend(buffer.get(), ArrowCharView(kExtensionMetadata), + ArrowStringView{ + metadata.c_str(), + static_cast(metadata.size()), + })); + NANOARROW_RETURN_NOT_OK( ArrowSchemaSetMetadata(schema, reinterpret_cast(buffer->data))); @@ -360,7 +407,18 @@ class PostgresTypeResolver { return EINVAL; } - *type_out = (*result).second; + *type_out = result->second; + return NANOARROW_OK; + } + + ArrowErrorCode FindWithDefault(uint32_t oid, PostgresType* type_out) { + auto result = mapping_.find(oid); + if (result == mapping_.end()) { + *type_out = PostgresType::Unnamed(oid); + } else { + *type_out = result->second; + } + return NANOARROW_OK; } @@ -523,16 +581,40 @@ inline ArrowErrorCode PostgresType::FromSchema(const PostgresTypeResolver& resol return resolver.Find(resolver.GetOID(PostgresTypeId::kInt4), out, error); case NANOARROW_TYPE_UINT32: case NANOARROW_TYPE_INT64: + case NANOARROW_TYPE_UINT64: return resolver.Find(resolver.GetOID(PostgresTypeId::kInt8), out, error); + case NANOARROW_TYPE_HALF_FLOAT: case NANOARROW_TYPE_FLOAT: return resolver.Find(resolver.GetOID(PostgresTypeId::kFloat4), out, error); case NANOARROW_TYPE_DOUBLE: return resolver.Find(resolver.GetOID(PostgresTypeId::kFloat8), out, error); case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: return resolver.Find(resolver.GetOID(PostgresTypeId::kText), out, error); case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: return resolver.Find(resolver.GetOID(PostgresTypeId::kBytea), out, error); + case NANOARROW_TYPE_DATE32: + case NANOARROW_TYPE_DATE64: + return resolver.Find(resolver.GetOID(PostgresTypeId::kDate), out, error); + case NANOARROW_TYPE_TIME32: + case NANOARROW_TYPE_TIME64: + return resolver.Find(resolver.GetOID(PostgresTypeId::kTime), out, error); + case NANOARROW_TYPE_DURATION: + case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: + return resolver.Find(resolver.GetOID(PostgresTypeId::kInterval), out, error); + case NANOARROW_TYPE_TIMESTAMP: + if (strcmp("", schema_view.timezone) == 0) { + return resolver.Find(resolver.GetOID(PostgresTypeId::kTimestamptz), out, error); + } else { + return resolver.Find(resolver.GetOID(PostgresTypeId::kTimestamp), out, error); + } + case NANOARROW_TYPE_DECIMAL128: + case NANOARROW_TYPE_DECIMAL256: + return resolver.Find(resolver.GetOID(PostgresTypeId::kNumeric), out, error); case NANOARROW_TYPE_LIST: case NANOARROW_TYPE_LARGE_LIST: case NANOARROW_TYPE_FIXED_SIZE_LIST: { diff --git a/c/driver/postgresql/postgres_type_test.cc b/c/driver/postgresql/postgres_type_test.cc index 9d6152f27f..2c76f4c1f4 100644 --- a/c/driver/postgresql/postgres_type_test.cc +++ b/c/driver/postgresql/postgres_type_test.cc @@ -174,6 +174,14 @@ TEST(PostgresTypeTest, 
PostgresTypeSetSchema) { &typnameMetadataValue); EXPECT_EQ(std::string(typnameMetadataValue.data, typnameMetadataValue.size_bytes), "numeric"); + ArrowMetadataGetValue(schema->metadata, ArrowCharView("ARROW:extension:name"), + &typnameMetadataValue); + EXPECT_EQ(std::string(typnameMetadataValue.data, typnameMetadataValue.size_bytes), + "arrow.opaque"); + ArrowMetadataGetValue(schema->metadata, ArrowCharView("ARROW:extension:metadata"), + &typnameMetadataValue); + EXPECT_EQ(std::string(typnameMetadataValue.data, typnameMetadataValue.size_bytes), + R"({"type_name": "numeric", "vendor_name": "PostgreSQL"})"); schema.reset(); ArrowSchemaInit(schema.get()); @@ -312,11 +320,10 @@ TEST(PostgresTypeTest, PostgresTypeFromSchema) { schema.reset(); ArrowError error; - ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO), + ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INTERVAL_MONTHS), NANOARROW_OK); EXPECT_EQ(PostgresType::FromSchema(resolver, schema.get(), &type, &error), ENOTSUP); - EXPECT_STREQ(error.message, - "Can't map Arrow type 'interval_month_day_nano' to Postgres type"); + EXPECT_STREQ(error.message, "Can't map Arrow type 'interval_months' to Postgres type"); schema.reset(); } @@ -330,6 +337,11 @@ TEST(PostgresTypeTest, PostgresTypeResolver) { EXPECT_EQ(resolver.Find(123, &type, &error), EINVAL); EXPECT_STREQ(ArrowErrorMessage(&error), "Postgres type with oid 123 not found"); + EXPECT_EQ(resolver.FindWithDefault(123, &type), NANOARROW_OK); + EXPECT_EQ(type.oid(), 123); + EXPECT_EQ(type.type_id(), PostgresTypeId::kUnnamedArrowOpaque); + EXPECT_EQ(type.typname(), "unnamed"); + // Check error for Array with unknown child item.oid = 123; item.typname = "some_array"; diff --git a/c/driver/postgresql/postgres_util.h b/c/driver/postgresql/postgres_util.h index 6d42f85a2c..fbd609848a 100644 --- a/c/driver/postgresql/postgres_util.h +++ b/c/driver/postgresql/postgres_util.h @@ -33,7 +33,7 @@ #include #endif -#include "adbc.h" +#include "arrow-adbc/adbc.h" namespace adbcpq { diff --git a/c/driver/postgresql/postgresql.cc b/c/driver/postgresql/postgresql.cc index e1ffd543b1..e43db98879 100644 --- a/c/driver/postgresql/postgresql.cc +++ b/c/driver/postgresql/postgresql.cc @@ -20,13 +20,15 @@ #include #include -#include +#include #include "connection.h" #include "database.h" #include "driver/common/utils.h" +#include "driver/framework/status.h" #include "statement.h" +using adbc::driver::Status; using adbcpq::PostgresConnection; using adbcpq::PostgresDatabase; using adbcpq::PostgresStatement; @@ -56,14 +58,36 @@ const struct AdbcError* PostgresErrorFromArrayStream(struct ArrowArrayStream* st // Currently only valid for TupleReader return adbcpq::TupleReader::ErrorFromArrayStream(stream, status); } + +int PostgresErrorGetDetailCount(const struct AdbcError* error) { + if (IsCommonError(error)) { + return CommonErrorGetDetailCount(error); + } + + if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + return 0; + } + + auto error_obj = reinterpret_cast(error->private_data); + return error_obj->CDetailCount(); +} + +struct AdbcErrorDetail PostgresErrorGetDetail(const struct AdbcError* error, int index) { + if (IsCommonError(error)) { + return CommonErrorGetDetail(error, index); + } + + auto error_obj = reinterpret_cast(error->private_data); + return error_obj->CDetail(index); +} } // namespace int AdbcErrorGetDetailCount(const struct AdbcError* error) { - return CommonErrorGetDetailCount(error); + return PostgresErrorGetDetailCount(error); } 
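Since the driver now routes detail lookups through PostgresErrorGetDetail(), an ADBC 1.1 application can enumerate the attached PG_DIAG_* fields directly from a failed call. A minimal sketch, assuming <cstdio> and arrow-adbc/adbc.h are included; PrintErrorDetails is a hypothetical helper, not part of this patch:

static void PrintErrorDetails(const struct AdbcError* error) {
  int count = AdbcErrorGetDetailCount(error);
  for (int i = 0; i < count; i++) {
    struct AdbcErrorDetail detail = AdbcErrorGetDetail(error, i);
    // Keys come from kDetailFields in error.h (e.g. PG_DIAG_SQLSTATE,
    // PG_DIAG_MESSAGE_PRIMARY); values are length-delimited, not NUL-terminated.
    std::printf("%s = %.*s\n", detail.key, static_cast<int>(detail.value_length),
                reinterpret_cast<const char*>(detail.value));
  }
}

For errors reported through a result stream, AdbcErrorFromArrayStream() returns the AdbcError to pass in here, as exercised by the SqlExecuteCopyZeroRowOutputError test later in this patch.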
struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { - return CommonErrorGetDetail(error, index); + return PostgresErrorGetDetail(error, index); } const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, @@ -860,8 +884,8 @@ AdbcStatusCode PostgresqlDriverInit(int version, void* raw_driver, if (version >= ADBC_VERSION_1_1_0) { std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE); - driver->ErrorGetDetailCount = CommonErrorGetDetailCount; - driver->ErrorGetDetail = CommonErrorGetDetail; + driver->ErrorGetDetailCount = PostgresErrorGetDetailCount; + driver->ErrorGetDetail = PostgresErrorGetDetail; driver->ErrorFromArrayStream = PostgresErrorFromArrayStream; driver->DatabaseGetOption = PostgresDatabaseGetOption; diff --git a/c/driver/postgresql/postgresql_benchmark.cc b/c/driver/postgresql/postgresql_benchmark.cc index a2373095f3..aa22e033d3 100644 --- a/c/driver/postgresql/postgresql_benchmark.cc +++ b/c/driver/postgresql/postgresql_benchmark.cc @@ -20,7 +20,7 @@ #include #include -#include "adbc.h" +#include "arrow-adbc/adbc.h" #include "validation/adbc_validation_util.h" #define _ADBC_BENCHMARK_RETURN_NOT_OK_IMPL(NAME, EXPR) \ diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 3c924d3917..be32bd893b 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include #include @@ -116,11 +116,24 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override { switch (ingest_type) { case NANOARROW_TYPE_INT8: + case NANOARROW_TYPE_UINT8: return NANOARROW_TYPE_INT16; + case NANOARROW_TYPE_UINT16: + return NANOARROW_TYPE_INT32; + case NANOARROW_TYPE_UINT32: + case NANOARROW_TYPE_UINT64: + return NANOARROW_TYPE_INT64; + case NANOARROW_TYPE_HALF_FLOAT: + return NANOARROW_TYPE_FLOAT; case NANOARROW_TYPE_DURATION: return NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO; case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: + return NANOARROW_TYPE_BINARY; case NANOARROW_TYPE_DECIMAL128: case NANOARROW_TYPE_DECIMAL256: return NANOARROW_TYPE_STRING; @@ -673,6 +686,47 @@ TEST_F(PostgresConnectionTest, MetadataSetCurrentDbSchema) { ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); } +TEST_F(PostgresConnectionTest, MetadataGetSchemaCaseSensitiveTable) { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + // Create sample table + { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + "DROP TABLE IF EXISTS \"Uppercase\"", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement.value, "CREATE TABLE \"Uppercase\" (ints INT, strs TEXT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + // Check its schema + nanoarrow::UniqueSchema schema; + 
ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, nullptr, nullptr, "Uppercase", + schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_NE(schema->release, nullptr); + ASSERT_STREQ(schema->format, "+s"); + ASSERT_EQ(schema->n_children, 2); + ASSERT_STREQ(schema->children[0]->format, "i"); + ASSERT_STREQ(schema->children[1]->format, "u"); + ASSERT_STREQ(schema->children[0]->name, "ints"); + ASSERT_STREQ(schema->children[1]->name, "strs"); + + // Do we have to release the connection here? +} + TEST_F(PostgresConnectionTest, MetadataGetStatistics) { if (!quirks()->supports_statistics()) { GTEST_SKIP(); @@ -845,11 +899,6 @@ class PostgresStatementTest : public ::testing::Test, void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); } void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); } - void TestSqlIngestUInt8() { GTEST_SKIP() << "Not implemented"; } - void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; } - void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; } - void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; } - void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; } void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; } void TestSqlPrepareSelectParams() { GTEST_SKIP() << "Not yet implemented"; } @@ -1098,10 +1147,11 @@ TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) { IsOkStatus(&error)); ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); - ASSERT_THAT(error.message, - ::testing::HasSubstr("Row #1 has value '9223372036854775807' which " - "exceeds PostgreSQL timestamp limits")); + IsStatus(ADBC_STATUS_INTERNAL, &error)); + ASSERT_THAT( + error.message, + ::testing::HasSubstr( + "Row 0 timestamp value 9223372036854775807 with unit 0 would overflow")); } { @@ -1128,10 +1178,11 @@ TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) { IsOkStatus(&error)); ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); - ASSERT_THAT(error.message, - ::testing::HasSubstr("Row #1 has value '-9223372036854775808' which " - "exceeds PostgreSQL timestamp limits")); + IsStatus(ADBC_STATUS_INTERNAL, &error)); + ASSERT_THAT( + error.message, + ::testing::HasSubstr( + "Row 0 timestamp value -9223372036854775808 with unit 0 would overflow")); } } @@ -1206,7 +1257,7 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) { ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, &reader.rows_affected, &error), IsOkStatus(&error)); - ASSERT_EQ(reader.rows_affected, -1); + ASSERT_EQ(reader.rows_affected, 2); ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_EQ(reader.array->release, nullptr); @@ -1235,6 +1286,226 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) { } } +TEST_F(PostgresStatementTest, ExecuteSchemaParameterizedQuery) { + nanoarrow::UniqueSchema schema_bind; + ArrowSchemaInit(schema_bind.get()); + ASSERT_THAT(ArrowSchemaSetTypeStruct(schema_bind.get(), 1), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetType(schema_bind->children[0], NANOARROW_TYPE_STRING), + adbc_validation::IsOkErrno()); + + nanoarrow::UniqueArrayStream bind; + 
nanoarrow::EmptyArrayStream(schema_bind.get()).ToArrayStream(bind.get()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT $1", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBindStream(&statement, bind.get(), &error), IsOkStatus()); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_STREQ("u", schema->children[0]->format); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithResult) { + nanoarrow::UniqueSchema schema_bind; + ArrowSchemaInit(schema_bind.get()); + ASSERT_THAT(ArrowSchemaSetTypeStruct(schema_bind.get(), 1), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetType(schema_bind->children[0], NANOARROW_TYPE_INT32), + adbc_validation::IsOkErrno()); + + nanoarrow::UniqueArray bind; + ASSERT_THAT(ArrowArrayInitFromSchema(bind.get(), schema_bind.get(), nullptr), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayStartAppending(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendInt(bind->children[0], 123), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendInt(bind->children[0], 456), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendNull(bind->children[0], 1), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishBuildingDefault(bind.get(), nullptr), + adbc_validation::IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT $1", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, bind.get(), schema_bind.get(), &error), + IsOkStatus()); + + { + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(reader.rows_affected, -1); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_EQ(reader.schema->n_children, 1); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_EQ(reader.array_view->children[0]->buffer_views[1].data.as_int32[0], 123); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_EQ(reader.array_view->children[0]->buffer_views[1].data.as_int32[0], 456); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_EQ(reader.array->children[0]->null_count, 1); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } +} + +TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithRowsAffected) { + // Check that when executing one or more parameterized queries that the corresponding + // affected row count is added. 
+ ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "CREATE TABLE adbc_test (ints INT)", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(reader.rows_affected, -1); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } + + { + // Use INSERT INTO + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "INSERT INTO adbc_test (ints) VALUES (123), (456)", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(reader.rows_affected, 2); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } + + nanoarrow::UniqueSchema schema_bind; + ArrowSchemaInit(schema_bind.get()); + ASSERT_THAT(ArrowSchemaSetTypeStruct(schema_bind.get(), 1), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetType(schema_bind->children[0], NANOARROW_TYPE_INT32), + adbc_validation::IsOkErrno()); + + nanoarrow::UniqueArray bind; + ASSERT_THAT(ArrowArrayInitFromSchema(bind.get(), schema_bind.get(), nullptr), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayStartAppending(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendInt(bind->children[0], 123), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendInt(bind->children[0], 456), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishBuildingDefault(bind.get(), nullptr), + adbc_validation::IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, + "DELETE FROM adbc_test WHERE ints = $1", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, bind.get(), schema_bind.get(), &error), + IsOkStatus()); + + { + int64_t rows_affected = -2; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(rows_affected, 2); + } + + { + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * from adbc_test", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } +} + +TEST_F(PostgresStatementTest, SqlExecuteCopyZeroRowOutputError) { + ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + { + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, "CREATE TABLE adbc_test (id int primary key, data jsonb)", + &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + } + + { + 
ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "insert into adbc_test (id, data) values (1, null)", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + } + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "insert into adbc_test (id, data) values (2, '1')", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + } + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, + "SELECT id, data from adbc_test JOIN " + "jsonb_array_elements(adbc_test.data) AS foo ON true", + &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus()); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_EQ(reader.MaybeNext(), EINVAL); + + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* detail = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(ADBC_STATUS_INVALID_ARGUMENT, status); + ASSERT_EQ("22023", std::string_view(detail->sqlstate, 5)); + } +} + TEST_F(PostgresStatementTest, BatchSizeHint) { ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "batch_size_hint_test", &error), IsOkStatus(&error)); @@ -1304,16 +1575,13 @@ TEST_F(PostgresStatementTest, AdbcErrorBackwardsCompatibility) { TEST_F(PostgresStatementTest, Cancel) { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); - for (const char* query : { - "DROP TABLE IF EXISTS test_cancel", - "CREATE TABLE test_cancel (ints INT)", - R"(INSERT INTO test_cancel (ints) - SELECT g :: INT FROM GENERATE_SERIES(1, 65536) temp(g))", - }) { - ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsOkStatus(&error)); - } + const char* query = R"(DROP TABLE IF EXISTS test_cancel; + CREATE TABLE test_cancel (ints INT); + INSERT INTO test_cancel (ints) + SELECT g :: INT FROM GENERATE_SERIES(1, 65536) temp(g);)"; + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM test_cancel", &error), IsOkStatus(&error)); @@ -1340,6 +1608,91 @@ TEST_F(PostgresStatementTest, Cancel) { ASSERT_NE(0, AdbcErrorGetDetailCount(detail)); } +TEST_F(PostgresStatementTest, MultipleStatementsSingleQuery) { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + const char* query = R"(DROP TABLE IF EXISTS test_query_statements; + CREATE TABLE test_query_statements (ints INT); + INSERT INTO test_query_statements VALUES((1)); + INSERT INTO test_query_statements VALUES((2)); + INSERT INTO test_query_statements VALUES((3));)"; + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT * FROM test_query_statements", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + 
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + reader.GetSchema(); + ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno()); + ASSERT_EQ(reader.array->length, 3); +} + +TEST_F(PostgresStatementTest, SetUseCopyFalse) { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + const char* query = R"(DROP TABLE IF EXISTS test_query_set_copy_false; + CREATE TABLE test_query_set_copy_false (ints INT); + INSERT INTO test_query_set_copy_false VALUES((1)); + INSERT INTO test_query_set_copy_false VALUES((NULL)); + INSERT INTO test_query_set_copy_false VALUES((3));)"; + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + // Check option setting/getting + ASSERT_EQ( + adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error), + "true"); + + ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy", + "not true or false", &error), + IsStatus(ADBC_STATUS_INVALID_ARGUMENT)); + + ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy", + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_EQ( + adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error), + "true"); + + ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy", + ADBC_OPTION_VALUE_DISABLED, &error), + IsOkStatus(&error)); + ASSERT_EQ( + adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error), + "false"); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, + "SELECT * FROM test_query_set_copy_false", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + + ASSERT_EQ(reader.rows_affected, 3); + + reader.GetSchema(); + ASSERT_EQ(reader.schema->n_children, 1); + ASSERT_STREQ(reader.schema->children[0]->format, "i"); + ASSERT_STREQ(reader.schema->children[0]->name, "ints"); + + ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno()); + ASSERT_EQ(reader.array->length, 3); + ASSERT_EQ(reader.array->n_children, 1); + ASSERT_EQ(reader.array->children[0]->null_count, 1); + + ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno()); + ASSERT_EQ(reader.array->release, nullptr); +} + struct TypeTestCase { std::string name; std::string sql_type; @@ -1436,7 +1789,7 @@ TEST_P(PostgresTypeTest, SelectValue) { // check type ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( - &reader.schema.value, {{std::nullopt, GetParam().arrow_type, true}})); + &reader.schema.value, {{"", GetParam().arrow_type, true}})); if (GetParam().arrow_type == NANOARROW_TYPE_TIMESTAMP) { if (GetParam().sql_type.find("WITH TIME ZONE") == std::string::npos) { ASSERT_STREQ(reader.schema->children[0]->format, "tsu:"); @@ -1906,9 +2259,8 @@ TEST_P(PostgresDecimalTest, SelectValue) { ArrowSchemaInit(&schema.value); ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), 0); - ASSERT_EQ( - PrivateArrowSchemaSetTypeDecimal(schema.value.children[0], type, precision, scale), - 0); + ASSERT_EQ(ArrowSchemaSetTypeDecimal(schema.value.children[0], type, precision, scale), + 0); ASSERT_EQ(ArrowSchemaSetName(schema.value.children[0], "col"), 0); 
ASSERT_THAT(adbc_validation::MakeBatch(&schema.value, &array.value, diff --git a/c/driver/postgresql/result_helper.cc b/c/driver/postgresql/result_helper.cc index ad5a54e00d..48c6804883 100644 --- a/c/driver/postgresql/result_helper.cc +++ b/c/driver/postgresql/result_helper.cc @@ -17,52 +17,195 @@ #include "result_helper.h" -#include "driver/common/utils.h" +#include +#include + +#define ADBC_FRAMEWORK_USE_FMT +#include "driver/framework/status.h" #include "error.h" namespace adbcpq { -PqResultHelper::~PqResultHelper() { - if (result_ != nullptr) { - PQclear(result_); - } -} +PqResultHelper::~PqResultHelper() { ClearResult(); } -AdbcStatusCode PqResultHelper::Prepare() { +Status PqResultHelper::PrepareInternal(int n_params, const Oid* param_oids) const { // TODO: make stmtName a unique identifier? PGresult* result = - PQprepare(conn_, /*stmtName=*/"", query_.c_str(), param_values_.size(), NULL); + PQprepare(conn_, /*stmtName=*/"", query_.c_str(), n_params, param_oids); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error_, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(conn_), query_.c_str()); + auto status = MakeStatus(result, "Failed to prepare query: {}\nQuery was:{}", + PQerrorMessage(conn_), query_.c_str()); PQclear(result); - return code; + return status; } PQclear(result); - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultHelper::Execute() { - std::vector param_c_strs; +Status PqResultHelper::Prepare() const { return PrepareInternal(0, nullptr); } - for (size_t index = 0; index < param_values_.size(); index++) { - param_c_strs.push_back(param_values_[index].c_str()); +Status PqResultHelper::Prepare(const std::vector& param_oids) const { + return PrepareInternal(static_cast(param_oids.size()), param_oids.data()); +} + +Status PqResultHelper::DescribePrepared() { + ClearResult(); + result_ = PQdescribePrepared(conn_, /*stmtName=*/""); + if (PQresultStatus(result_) != PGRES_COMMAND_OK) { + Status status = MakeStatus( + result_, "[libpq] Failed to describe prepared statement: {}\nQuery was:{}", + PQerrorMessage(conn_), query_.c_str()); + ClearResult(); + return status; } - result_ = - PQexecPrepared(conn_, "", param_values_.size(), param_c_strs.data(), NULL, NULL, 0); + return Status::Ok(); +} + +Status PqResultHelper::Execute(const std::vector& params, + PostgresType* param_types) { + if (params.size() == 0 && param_types == nullptr && output_format_ == Format::kText) { + ClearResult(); + result_ = PQexec(conn_, query_.c_str()); + } else { + std::vector param_values; + std::vector param_lengths; + std::vector param_formats; + + for (const auto& param : params) { + param_values.push_back(param.data()); + param_lengths.push_back(static_cast(param.size())); + param_formats.push_back(static_cast(param_format_)); + } + + std::vector param_oids; + const Oid* param_oids_ptr = nullptr; + if (param_types != nullptr) { + param_oids.resize(params.size()); + for (size_t i = 0; i < params.size(); i++) { + param_oids[i] = param_types->child(i).oid(); + } + param_oids_ptr = param_oids.data(); + } + + ClearResult(); + result_ = PQexecParams(conn_, query_.c_str(), static_cast(param_values.size()), + param_oids_ptr, param_values.data(), param_lengths.data(), + param_formats.data(), static_cast(output_format_)); + } ExecStatusType status = PQresultStatus(result_); if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK) { - AdbcStatusCode error = - SetError(error_, result_, "[libpq] Failed to execute query 
'%s': %s", - query_.c_str(), PQerrorMessage(conn_)); - return error; + return MakeStatus(result_, "[libpq] Failed to execute query '{}': {}", query_.c_str(), + PQerrorMessage(conn_)); + } + + return Status::Ok(); +} + +Status PqResultHelper::ExecuteCopy() { + // Remove trailing semicolon(s) from the query before feeding it into COPY + while (!query_.empty() && query_.back() == ';') { + query_.pop_back(); + } + + std::string copy_query = "COPY (" + query_ + ") TO STDOUT (FORMAT binary)"; + ClearResult(); + result_ = PQexecParams(conn_, copy_query.c_str(), /*nParams=*/0, + /*paramTypes=*/nullptr, /*paramValues=*/nullptr, + /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, + static_cast(Format::kBinary)); + + if (PQresultStatus(result_) != PGRES_COPY_OUT) { + Status status = MakeStatus( + result_, + "[libpq] Failed to execute query: could not begin COPY: {}\nQuery was: {}", + PQerrorMessage(conn_), copy_query.c_str()); + ClearResult(); + return status; + } + + return Status::Ok(); +} + +Status PqResultHelper::ResolveParamTypes(PostgresTypeResolver& type_resolver, + PostgresType* param_types) { + struct ArrowError na_error; + ArrowErrorInit(&na_error); + + const int num_params = PQnparams(result_); + PostgresType root_type(PostgresTypeId::kRecord); + + for (int i = 0; i < num_params; i++) { + const Oid pg_oid = PQparamtype(result_, i); + PostgresType pg_type; + if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) { + Status status = Status::NotImplemented("[libpq] Parameter #", i + 1, " (\"", + PQfname(result_, i), + "\") has unknown type code ", pg_oid); + ClearResult(); + return status; + } + + root_type.AppendChild(PQfname(result_, i), pg_type); + } + + *param_types = root_type; + return Status::Ok(); +} + +Status PqResultHelper::ResolveOutputTypes(PostgresTypeResolver& type_resolver, + PostgresType* result_types) { + struct ArrowError na_error; + ArrowErrorInit(&na_error); + + const int num_fields = PQnfields(result_); + PostgresType root_type(PostgresTypeId::kRecord); + + for (int i = 0; i < num_fields; i++) { + const Oid pg_oid = PQftype(result_, i); + PostgresType pg_type; + if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) { + Status status = + Status::NotImplemented("[libpq] Column #", i + 1, " (\"", PQfname(result_, i), + "\") has unknown type code ", pg_oid); + ClearResult(); + return status; + } + + root_type.AppendChild(PQfname(result_, i), pg_type); } - return ADBC_STATUS_OK; + *result_types = root_type; + return Status::Ok(); +} + +PGresult* PqResultHelper::ReleaseResult() { + PGresult* out = result_; + result_ = nullptr; + return out; +} + +int64_t PqResultHelper::AffectedRows() const { + if (result_ == nullptr) { + return -1; + } + + char* first = PQcmdTuples(result_); + char* last = first + strlen(first); + if ((last - first) == 0) { + return -1; + } + + int64_t out; + auto result = std::from_chars(first, last, out); + + if (result.ec == std::errc() && result.ptr == last) { + return out; + } else { + return -1; + } } } // namespace adbcpq diff --git a/c/driver/postgresql/result_helper.h b/c/driver/postgresql/result_helper.h index 8eec8dc347..1f3f93c46b 100644 --- a/c/driver/postgresql/result_helper.h +++ b/c/driver/postgresql/result_helper.h @@ -18,14 +18,23 @@ #pragma once #include +#include +#include +#include #include #include #include #include -#include +#include #include +#include "copy/reader.h" +#include "driver/framework/status.h" + +using adbc::driver::Result; +using adbc::driver::Status; + namespace adbcpq { /// \brief A single 
column in a single row of a result set. @@ -42,18 +51,46 @@ struct PqRecord { } return result; } + + Result<int64_t> ParseInteger() const { + const char* last = data + len; + int64_t value = 0; + auto result = std::from_chars(data, last, value, 10); + if (result.ec == std::errc() && result.ptr == last) { + return value; + } else { + return Status::Internal("Can't parse '", data, "' as integer"); + } + } + + Result<std::vector<std::string>> ParseTextArray() const { + std::string text_array(data, len); + text_array.erase(0, 1); + text_array.erase(text_array.size() - 1); + + std::vector<std::string> elements; + std::stringstream ss(std::move(text_array)); + std::string tmp; + + while (getline(ss, tmp, ',')) { + elements.push_back(std::move(tmp)); + } + + return elements; + } + + std::string_view value() { return std::string_view(data, len); } }; // Used by PqResultHelper to provide index-based access to the records within each // row of a PGresult class PqResultRow { public: - PqResultRow(PGresult* result, int row_num) : result_(result), row_num_(row_num) { - ncols_ = PQnfields(result); - } + PqResultRow() : result_(nullptr), row_num_(-1) {} + PqResultRow(PGresult* result, int row_num) : result_(result), row_num_(row_num) {} - PqRecord operator[](const int& col_num) { - assert(col_num < ncols_); + PqRecord operator[](int col_num) const { + assert(col_num < PQnfields(result_)); const char* data = PQgetvalue(result_, row_num_, col_num); const int len = PQgetlength(result_, row_num_, col_num); const bool is_null = PQgetisnull(result_, row_num_, col_num); @@ -61,10 +98,15 @@ class PqResultRow { return PqRecord{data, len, is_null}; } + bool IsValid() const { + return result_ && row_num_ >= 0 && row_num_ < PQntuples(result_); + } + + PqResultRow Next() const { return PqResultRow(result_, row_num_ + 1); } + private: PGresult* result_ = nullptr; int row_num_; - int ncols_; }; // Helper to manage the lifecycle of a PGresult. 
The query argument @@ -73,25 +115,62 @@ class PqResultRow { // prior to iterating class PqResultHelper { public: - explicit PqResultHelper(PGconn* conn, std::string query, struct AdbcError* error) - : conn_(conn), query_(std::move(query)), error_(error) {} + enum class Format { + kText = 0, + kBinary = 1, + }; - explicit PqResultHelper(PGconn* conn, std::string query, - std::vector param_values, struct AdbcError* error) - : conn_(conn), - query_(std::move(query)), - param_values_(std::move(param_values)), - error_(error) {} + explicit PqResultHelper(PGconn* conn, std::string query) + : conn_(conn), query_(std::move(query)) {} + + PqResultHelper(PqResultHelper&& other) + : PqResultHelper(other.conn_, std::move(other.query_)) { + result_ = other.result_; + other.result_ = nullptr; + } ~PqResultHelper(); - AdbcStatusCode Prepare(); - AdbcStatusCode Execute(); + void set_param_format(Format format) { param_format_ = format; } + void set_output_format(Format format) { output_format_ = format; } + + Status Prepare() const; + Status Prepare(const std::vector& param_oids) const; + Status DescribePrepared(); + Status Execute(const std::vector& params = {}, + PostgresType* param_types = nullptr); + Status ExecuteCopy(); + Status ResolveParamTypes(PostgresTypeResolver& type_resolver, + PostgresType* param_types); + Status ResolveOutputTypes(PostgresTypeResolver& type_resolver, + PostgresType* result_types); + + bool HasResult() const { return result_ != nullptr; } + + void SetResult(PGresult* result) { + ClearResult(); + result_ = result; + } + + PGresult* ReleaseResult(); + + void ClearResult() { + PQclear(result_); + result_ = nullptr; + } + + int64_t AffectedRows() const; int NumRows() const { return PQntuples(result_); } int NumColumns() const { return PQnfields(result_); } + const char* FieldName(int column_number) const { + return PQfname(result_, column_number); + } + Oid FieldType(int column_number) const { return PQftype(result_, column_number); } + PqResultRow Row(int i) const { return PqResultRow(result_, i); } + class iterator { const PqResultHelper& outer_; int curr_row_ = 0; @@ -112,7 +191,7 @@ class PqResultHelper { return outer_.result_ == other.outer_.result_ && curr_row_ == other.curr_row_; } bool operator!=(iterator other) const { return !(*this == other); } - PqResultRow operator*() { return PqResultRow(outer_.result_, curr_row_); } + PqResultRow operator*() const { return PqResultRow(outer_.result_, curr_row_); } using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; using value_type = std::vector; @@ -120,14 +199,17 @@ class PqResultHelper { using reference = const std::vector&; }; - iterator begin() { return iterator(*this); } - iterator end() { return iterator(*this, NumRows()); } + iterator begin() const { return iterator(*this); } + iterator end() const { return iterator(*this, NumRows()); } private: PGresult* result_ = nullptr; PGconn* conn_; std::string query_; - std::vector param_values_; - struct AdbcError* error_; + Format param_format_ = Format::kText; + Format output_format_ = Format::kText; + + Status PrepareInternal(int n_params, const Oid* param_oids) const; }; + } // namespace adbcpq diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc new file mode 100644 index 0000000000..61d17bb038 --- /dev/null +++ b/c/driver/postgresql/result_reader.cc @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "result_reader.h" + +#include +#include + +#include "copy/reader.h" +#include "driver/framework/status.h" + +namespace adbcpq { + +int PqResultArrayReader::GetSchema(struct ArrowSchema* out) { + ResetErrors(); + + if (schema_->release == nullptr) { + Status status = Initialize(nullptr); + if (!status.ok()) { + status.ToAdbc(&error_); + return EINVAL; + } + } + + return ArrowSchemaDeepCopy(schema_.get(), out); +} + +int PqResultArrayReader::GetNext(struct ArrowArray* out) { + ResetErrors(); + + Status status; + if (schema_->release == nullptr) { + status = Initialize(nullptr); + if (!status.ok()) { + status.ToAdbc(&error_); + return EINVAL; + } + } + + // If don't already have a result, populate it by binding the next row + // in the bind stream. If this is the first call to GetNext(), we have + // already populated the result. + if (!helper_.HasResult()) { + // If there was no bind stream provided or the existing bind stream has been + // exhausted, we are done. + if (!bind_stream_) { + out->release = nullptr; + return NANOARROW_OK; + } + + // Keep binding and executing until we have a result to return + status = BindNextAndExecute(nullptr); + if (!status.ok()) { + status.ToAdbc(&error_); + return EIO; + } + + // It's possible that there is still nothing to do here + if (!helper_.HasResult()) { + out->release = nullptr; + return NANOARROW_OK; + } + } + + nanoarrow::UniqueArray tmp; + NANOARROW_RETURN_NOT_OK(ArrowArrayInitFromSchema(tmp.get(), schema_.get(), &na_error_)); + NANOARROW_RETURN_NOT_OK(ArrowArrayStartAppending(tmp.get())); + for (int i = 0; i < helper_.NumColumns(); i++) { + NANOARROW_RETURN_NOT_OK(field_readers_[i]->InitArray(tmp->children[i])); + } + + // TODO: If we get an EOVERFLOW here (e.g., big string data), we + // would need to keep track of what row number we're on and start + // from there instead of begin() on the next call. We could also + // respect the size hint here to chunk the batches. 
+ struct ArrowBufferView item; + for (auto it = helper_.begin(); it != helper_.end(); it++) { + auto row = *it; + for (int i = 0; i < helper_.NumColumns(); i++) { + auto pg_item = row[i]; + item.data.data = pg_item.data; + + if (pg_item.is_null) { + item.size_bytes = -1; + } else { + item.size_bytes = pg_item.len; + } + + NANOARROW_RETURN_NOT_OK(field_readers_[i]->Read( + &item, static_cast(item.size_bytes), tmp->children[i], &na_error_)); + } + } + + for (int i = 0; i < helper_.NumColumns(); i++) { + NANOARROW_RETURN_NOT_OK(field_readers_[i]->FinishArray(tmp->children[i], &na_error_)); + } + + tmp->length = helper_.NumRows(); + tmp->null_count = 0; + NANOARROW_RETURN_NOT_OK(ArrowArrayFinishBuildingDefault(tmp.get(), &na_error_)); + + // Signal that the next call to GetNext() will have to populate the result again + helper_.ClearResult(); + + // Canonically return zero-size results as an empty stream + if (tmp->length == 0) { + out->release = nullptr; + return NANOARROW_OK; + } + + ArrowArrayMove(tmp.get(), out); + return NANOARROW_OK; +} + +const char* PqResultArrayReader::GetLastError() { + if (error_.message != nullptr) { + return error_.message; + } else { + return na_error_.message; + } +} + +Status PqResultArrayReader::Initialize(int64_t* rows_affected) { + helper_.set_output_format(PqResultHelper::Format::kBinary); + helper_.set_param_format(PqResultHelper::Format::kBinary); + + // If we have to do binding, set up the bind stream an execute until + // there is a result with more than zero rows to populate. + if (bind_stream_) { + UNWRAP_STATUS(bind_stream_->Begin([] { return Status::Ok(); })); + + UNWRAP_STATUS(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_)); + UNWRAP_STATUS(helper_.Prepare(bind_stream_->param_types)); + UNWRAP_STATUS(BindNextAndExecute(nullptr)); + + // If there were no arrays in the bind stream, we still need a result + // to populate the schema. If there were any arrays in the bind stream, + // the last one will still be in helper_ even if it had zero rows. + if (!helper_.HasResult()) { + UNWRAP_STATUS(helper_.DescribePrepared()); + } + + // We can't provide affected row counts if there is a bind stream and + // an output because we don't know how many future bind arrays/rows there + // might be. 
+ if (rows_affected != nullptr) { + *rows_affected = -1; + } + } else { + UNWRAP_STATUS(helper_.Execute()); + if (rows_affected != nullptr) { + *rows_affected = helper_.AffectedRows(); + } + } + + // Build the schema for which we are about to build results + ArrowSchemaInit(schema_.get()); + UNWRAP_NANOARROW(na_error_, Internal, + ArrowSchemaSetTypeStruct(schema_.get(), helper_.NumColumns())); + + for (int i = 0; i < helper_.NumColumns(); i++) { + PostgresType child_type; + UNWRAP_ERRNO(Internal, + type_resolver_->FindWithDefault(helper_.FieldType(i), &child_type)); + + UNWRAP_ERRNO(Internal, child_type.SetSchema(schema_->children[i], vendor_name_)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(schema_->children[i], helper_.FieldName(i))); + + std::unique_ptr child_reader; + UNWRAP_NANOARROW( + na_error_, Internal, + MakeCopyFieldReader(child_type, schema_->children[i], &child_reader, &na_error_)); + + child_reader->Init(child_type); + UNWRAP_NANOARROW(na_error_, Internal, child_reader->InitSchema(schema_->children[i])); + + field_readers_.push_back(std::move(child_reader)); + } + + return Status::Ok(); +} + +Status PqResultArrayReader::ToArrayStream(int64_t* affected_rows, + struct ArrowArrayStream* out) { + if (out == nullptr) { + // If there is no output requested, we still need to execute and + // set affected_rows if needed. We don't need an output schema or to set up a copy + // reader, so we can skip those steps by going straight to Execute(). This also + // enables us to support queries with multiple statements because we can call PQexec() + // instead of PQexecParams(). + UNWRAP_STATUS(ExecuteAll(affected_rows)); + return Status::Ok(); + } + + // Otherwise, execute until we have a result to return. We need this to provide row + // counts for DELETE and CREATE TABLE queries as well as to provide more informative + // errors until this reader class is wired up to provide extended AdbcError information. + UNWRAP_STATUS(Initialize(affected_rows)); + + nanoarrow::ArrayStreamFactory::InitArrayStream( + new PqResultArrayReader(this), out); + + return Status::Ok(); +} + +Status PqResultArrayReader::BindNextAndExecute(int64_t* affected_rows) { + // Keep pulling from the bind stream and executing as long as + // we receive results with zero rows. + do { + UNWRAP_STATUS(bind_stream_->EnsureNextRow()); + + if (!bind_stream_->current->release) { + UNWRAP_STATUS(bind_stream_->Cleanup(conn_)); + bind_stream_.reset(); + return Status::Ok(); + } + + PGresult* result; + UNWRAP_STATUS(bind_stream_->BindAndExecuteCurrentRow( + conn_, &result, /*result_format*/ kPgBinaryFormat)); + helper_.SetResult(result); + if (affected_rows) { + (*affected_rows) += helper_.AffectedRows(); + } + } while (helper_.NumRows() == 0); + + return Status::Ok(); +} + +Status PqResultArrayReader::ExecuteAll(int64_t* affected_rows) { + // For the case where we don't need a result, we either need to exhaust the bind + // stream (if there is one) or execute the query without binding. 
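// Illustrative sketch (not from this diff): the no-result branch handled by ExecuteAll()
// is what a caller reaches by binding parameters and then passing a null output stream to
// AdbcStatementExecuteQuery(), e.g. for a parameterized INSERT. The sketch below uses only
// the public ADBC statement API; the table name and the `params` stream of STRUCT batches
// are hypothetical placeholders.

static AdbcStatusCode ExecuteWithoutResult(struct AdbcStatement* stmt,
                                           struct ArrowArrayStream* params,
                                           int64_t* rows_affected,
                                           struct AdbcError* error) {
  AdbcStatusCode code =
      AdbcStatementSetSqlQuery(stmt, "INSERT INTO my_table VALUES ($1, $2)", error);
  if (code != ADBC_STATUS_OK) return code;

  code = AdbcStatementBindStream(stmt, params, error);
  if (code != ADBC_STATUS_OK) return code;

  // A null output stream tells the driver it only needs to run the statement and report
  // affected rows, i.e. the path described in the comment above.
  return AdbcStatementExecuteQuery(stmt, /*out=*/nullptr, rows_affected, error);
}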
+ if (bind_stream_) { + UNWRAP_STATUS(bind_stream_->Begin([] { return Status::Ok(); })); + UNWRAP_STATUS(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_)); + UNWRAP_STATUS(helper_.Prepare(bind_stream_->param_types)); + + // Reset affected rows to zero before binding and executing any + if (affected_rows) { + (*affected_rows) = 0; + } + + do { + UNWRAP_STATUS(BindNextAndExecute(affected_rows)); + } while (bind_stream_); + } else { + UNWRAP_STATUS(helper_.Execute()); + + if (affected_rows != nullptr) { + *affected_rows = helper_.AffectedRows(); + } + } + + return Status::Ok(); +} + +} // namespace adbcpq diff --git a/c/driver/postgresql/result_reader.h b/c/driver/postgresql/result_reader.h new file mode 100644 index 0000000000..90b35baf06 --- /dev/null +++ b/c/driver/postgresql/result_reader.h @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#if !defined(NOMINMAX) +#define NOMINMAX +#endif + +#include +#include +#include +#include + +#include + +#include "bind_stream.h" +#include "copy/reader.h" +#include "result_helper.h" + +namespace adbcpq { + +class PqResultArrayReader { + public: + PqResultArrayReader(PGconn* conn, std::shared_ptr type_resolver, + std::string query) + : conn_(conn), + helper_(conn, std::move(query)), + type_resolver_(type_resolver), + autocommit_(false) { + ArrowErrorInit(&na_error_); + error_ = ADBC_ERROR_INIT; + } + + ~PqResultArrayReader() { ResetErrors(); } + + // Ensure the reader knows what the autocommit status was on creation. This is used + // so that the temporary timezone setting required for parameter binding can be wrapped + // in a transaction (or not) accordingly. 
+ void SetAutocommit(bool autocommit) { autocommit_ = autocommit; } + + void SetBind(struct ArrowArrayStream* stream) { + bind_stream_ = std::make_unique(); + bind_stream_->SetBind(stream); + } + + void SetVendorName(std::string_view vendor_name) { + vendor_name_ = std::string(vendor_name); + } + + int GetSchema(struct ArrowSchema* out); + int GetNext(struct ArrowArray* out); + const char* GetLastError(); + + Status ToArrayStream(int64_t* affected_rows, struct ArrowArrayStream* out); + + Status Initialize(int64_t* affected_rows); + + private: + PGconn* conn_; + PqResultHelper helper_; + std::unique_ptr bind_stream_; + std::shared_ptr type_resolver_; + std::vector> field_readers_; + nanoarrow::UniqueSchema schema_; + bool autocommit_; + std::string vendor_name_; + struct AdbcError error_; + struct ArrowError na_error_; + + explicit PqResultArrayReader(PqResultArrayReader* other) + : conn_(other->conn_), + helper_(std::move(other->helper_)), + bind_stream_(std::move(other->bind_stream_)), + type_resolver_(std::move(other->type_resolver_)), + field_readers_(std::move(other->field_readers_)), + schema_(std::move(other->schema_)) { + ArrowErrorInit(&na_error_); + error_ = ADBC_ERROR_INIT; + } + + Status BindNextAndExecute(int64_t* affected_rows); + Status ExecuteAll(int64_t* affected_rows); + + void ResetErrors() { + ArrowErrorInit(&na_error_); + + if (error_.private_data != nullptr) { + error_.release(&error_); + } + error_ = ADBC_ERROR_INIT; + } +}; + +} // namespace adbcpq diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 93e88ee82c..9518e378eb 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -20,6 +20,7 @@ #include "statement.h" +#include #include #include #include @@ -31,614 +32,26 @@ #include #include -#include +#include #include #include +#include "bind_stream.h" #include "connection.h" -#include "copy/writer.h" #include "driver/common/options.h" #include "driver/common/utils.h" +#include "driver/framework/utility.h" #include "error.h" #include "postgres_type.h" #include "postgres_util.h" #include "result_helper.h" +#include "result_reader.h" namespace adbcpq { -namespace { -/// The flag indicating to PostgreSQL that we want binary-format values. 
-constexpr int kPgBinaryFormat = 1; - -/// One-value ArrowArrayStream used to unify the implementations of Bind -struct OneValueStream { - struct ArrowSchema schema; - struct ArrowArray array; - - static int GetSchema(struct ArrowArrayStream* self, struct ArrowSchema* out) { - OneValueStream* stream = static_cast(self->private_data); - return ArrowSchemaDeepCopy(&stream->schema, out); - } - static int GetNext(struct ArrowArrayStream* self, struct ArrowArray* out) { - OneValueStream* stream = static_cast(self->private_data); - *out = stream->array; - stream->array.release = nullptr; - return 0; - } - static const char* GetLastError(struct ArrowArrayStream* self) { return NULL; } - static void Release(struct ArrowArrayStream* self) { - OneValueStream* stream = static_cast(self->private_data); - if (stream->schema.release) { - stream->schema.release(&stream->schema); - stream->schema.release = nullptr; - } - if (stream->array.release) { - stream->array.release(&stream->array); - stream->array.release = nullptr; - } - delete stream; - self->release = nullptr; - } -}; - -/// Build an PostgresType object from a PGresult* -AdbcStatusCode ResolvePostgresType(const PostgresTypeResolver& type_resolver, - PGresult* result, PostgresType* out, - struct AdbcError* error) { - ArrowError na_error; - const int num_fields = PQnfields(result); - PostgresType root_type(PostgresTypeId::kRecord); - - for (int i = 0; i < num_fields; i++) { - const Oid pg_oid = PQftype(result, i); - PostgresType pg_type; - if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) { - SetError(error, "%s%d%s%s%s%d", "[libpq] Column #", i + 1, " (\"", - PQfname(result, i), "\") has unknown type code ", pg_oid); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - root_type.AppendChild(PQfname(result, i), pg_type); - } - - *out = root_type; - return ADBC_STATUS_OK; -} - -/// Helper to manage bind parameters with a prepared statement -struct BindStream { - Handle bind; - Handle bind_schema; - struct ArrowSchemaView bind_schema_view; - std::vector bind_schema_fields; - - // OIDs for parameter types - std::vector param_types; - std::vector param_values; - std::vector param_lengths; - std::vector param_formats; - std::vector param_values_offsets; - std::vector param_values_buffer; - // XXX: this assumes fixed-length fields only - will need more - // consideration to deal with variable-length fields - - bool has_tz_field = false; - std::string tz_setting; - - struct ArrowError na_error; - - explicit BindStream(struct ArrowArrayStream&& bind) { - this->bind.value = std::move(bind); - std::memset(&na_error, 0, sizeof(na_error)); - } - - template - AdbcStatusCode Begin(Callback&& callback, struct AdbcError* error) { - CHECK_NA(INTERNAL, bind->get_schema(&bind.value, &bind_schema.value), error); - CHECK_NA( - INTERNAL, - ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, /*error*/ nullptr), - error); - - if (bind_schema_view.type != ArrowType::NANOARROW_TYPE_STRUCT) { - SetError(error, "%s", "[libpq] Bind parameters must have type STRUCT"); - return ADBC_STATUS_INVALID_STATE; - } - - bind_schema_fields.resize(bind_schema->n_children); - for (size_t i = 0; i < bind_schema_fields.size(); i++) { - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&bind_schema_fields[i], bind_schema->children[i], - /*error*/ nullptr), - error); - } - - return std::move(callback)(); - } - - AdbcStatusCode SetParamTypes(const PostgresTypeResolver& type_resolver, - struct AdbcError* error) { - param_types.resize(bind_schema->n_children); - 
param_values.resize(bind_schema->n_children); - param_lengths.resize(bind_schema->n_children); - param_formats.resize(bind_schema->n_children, kPgBinaryFormat); - param_values_offsets.reserve(bind_schema->n_children); - - for (size_t i = 0; i < bind_schema_fields.size(); i++) { - PostgresTypeId type_id; - switch (bind_schema_fields[i].type) { - case ArrowType::NANOARROW_TYPE_BOOL: - type_id = PostgresTypeId::kBool; - param_lengths[i] = 1; - break; - case ArrowType::NANOARROW_TYPE_INT8: - case ArrowType::NANOARROW_TYPE_INT16: - type_id = PostgresTypeId::kInt2; - param_lengths[i] = 2; - break; - case ArrowType::NANOARROW_TYPE_INT32: - type_id = PostgresTypeId::kInt4; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_INT64: - type_id = PostgresTypeId::kInt8; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_FLOAT: - type_id = PostgresTypeId::kFloat4; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_DOUBLE: - type_id = PostgresTypeId::kFloat8; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - type_id = PostgresTypeId::kText; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_BINARY: - type_id = PostgresTypeId::kBytea; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_DATE32: - type_id = PostgresTypeId::kDate; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_TIMESTAMP: - type_id = PostgresTypeId::kTimestamp; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - type_id = PostgresTypeId::kInterval; - param_lengths[i] = 16; - break; - case ArrowType::NANOARROW_TYPE_DECIMAL128: - case ArrowType::NANOARROW_TYPE_DECIMAL256: - type_id = PostgresTypeId::kNumeric; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_DICTIONARY: { - struct ArrowSchemaView value_view; - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&value_view, bind_schema->children[i]->dictionary, - nullptr), - error); - switch (value_view.type) { - case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_LARGE_BINARY: - type_id = PostgresTypeId::kBytea; - param_lengths[i] = 0; - break; - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: - type_id = PostgresTypeId::kText; - param_lengths[i] = 0; - break; - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", - bind_schema->children[i]->name, - "') has unsupported dictionary value parameter type ", - ArrowTypeString(value_view.type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - break; - } - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", bind_schema->children[i]->name, - "') has unsupported parameter type ", - ArrowTypeString(bind_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - param_types[i] = type_resolver.GetOID(type_id); - if (param_types[i] == 0) { - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", bind_schema->children[i]->name, - "') has type with no corresponding PostgreSQL type ", - ArrowTypeString(bind_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - } - - size_t param_values_length = 0; - for (int length : param_lengths) { - param_values_offsets.push_back(param_values_length); - param_values_length += length; - } - param_values_buffer.resize(param_values_length); - return ADBC_STATUS_OK; - } - - AdbcStatusCode Prepare(PGconn* conn, 
const std::string& query, struct AdbcError* error, - const bool autocommit) { - // tz-aware timestamps require special handling to set the timezone to UTC - // prior to sending over the binary protocol; must be reset after execute - for (int64_t col = 0; col < bind_schema->n_children; col++) { - if ((bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) && - (strcmp("", bind_schema_fields[col].timezone))) { - has_tz_field = true; - - if (autocommit) { - PGresult* begin_result = PQexec(conn, "BEGIN"); - if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, begin_result, - "[libpq] Failed to begin transaction for timezone data: %s", - PQerrorMessage(conn)); - PQclear(begin_result); - return code; - } - PQclear(begin_result); - } - - PGresult* get_tz_result = PQexec(conn, "SELECT current_setting('TIMEZONE')"); - if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { - AdbcStatusCode code = SetError(error, get_tz_result, - "[libpq] Could not query current timezone: %s", - PQerrorMessage(conn)); - PQclear(get_tz_result); - return code; - } - - tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); - PQclear(get_tz_result); - - PGresult* set_utc_result = PQexec(conn, "SET TIME ZONE 'UTC'"); - if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = SetError(error, set_utc_result, - "[libpq] Failed to set time zone to UTC: %s", - PQerrorMessage(conn)); - PQclear(set_utc_result); - return code; - } - PQclear(set_utc_result); - break; - } - } - - PGresult* result = PQprepare(conn, /*stmtName=*/"", query.c_str(), - /*nParams=*/bind_schema->n_children, param_types.data()); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(conn), query.c_str()); - PQclear(result); - return code; - } - PQclear(result); - return ADBC_STATUS_OK; - } - - AdbcStatusCode Execute(PGconn* conn, int64_t* rows_affected, struct AdbcError* error) { - if (rows_affected) *rows_affected = 0; - PGresult* result = nullptr; - - while (true) { - Handle array; - int res = bind->get_next(&bind.value, &array.value); - if (res != 0) { - SetError(error, - "[libpq] Failed to read next batch from stream of bind parameters: " - "(%d) %s %s", - res, std::strerror(res), bind->get_last_error(&bind.value)); - return ADBC_STATUS_IO; - } - if (!array->release) break; - - Handle array_view; - // TODO: include error messages - CHECK_NA( - INTERNAL, - ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, nullptr), - error); - CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr), - error); - - for (int64_t row = 0; row < array->length; row++) { - for (int64_t col = 0; col < array_view->n_children; col++) { - if (ArrowArrayViewIsNull(array_view->children[col], row)) { - param_values[col] = nullptr; - continue; - } else { - param_values[col] = param_values_buffer.data() + param_values_offsets[col]; - } - switch (bind_schema_fields[col].type) { - case ArrowType::NANOARROW_TYPE_BOOL: { - const int8_t val = ArrowBitGet( - array_view->children[col]->buffer_views[1].data.as_uint8, row); - std::memcpy(param_values[col], &val, sizeof(int8_t)); - break; - } - - case ArrowType::NANOARROW_TYPE_INT8: { - const int16_t val = - array_view->children[col]->buffer_views[1].data.as_int8[row]; - const uint16_t value = ToNetworkInt16(val); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; - } - case 
ArrowType::NANOARROW_TYPE_INT16: { - const uint16_t value = ToNetworkInt16( - array_view->children[col]->buffer_views[1].data.as_int16[row]); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT32: { - const uint32_t value = ToNetworkInt32( - array_view->children[col]->buffer_views[1].data.as_int32[row]); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT64: { - const int64_t value = ToNetworkInt64( - array_view->children[col]->buffer_views[1].data.as_int64[row]); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_FLOAT: { - const uint32_t value = ToNetworkFloat4( - array_view->children[col]->buffer_views[1].data.as_float[row]); - std::memcpy(param_values[col], &value, sizeof(uint32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DOUBLE: { - const uint64_t value = ToNetworkFloat8( - array_view->children[col]->buffer_views[1].data.as_double[row]); - std::memcpy(param_values[col], &value, sizeof(uint64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - case ArrowType::NANOARROW_TYPE_BINARY: { - const ArrowBufferView view = - ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); - // TODO: overflow check? - param_lengths[col] = static_cast(view.size_bytes); - param_values[col] = const_cast(view.data.as_char); - break; - } - case ArrowType::NANOARROW_TYPE_DATE32: { - // 2000-01-01 - constexpr int32_t kPostgresDateEpoch = 10957; - const int32_t raw_value = - array_view->children[col]->buffer_views[1].data.as_int32[row]; - if (raw_value < INT32_MIN + kPostgresDateEpoch) { - SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, - "('", bind_schema->children[col]->name, "') Row #", row + 1, - "has value which exceeds postgres date limits"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_TIMESTAMP: { - int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; - - bool overflow_safe = true; - - auto unit = bind_schema_fields[col].time_unit; - - switch (unit) { - case NANOARROW_TIME_UNIT_SECOND: - overflow_safe = - val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; - if (overflow_safe) { - val *= 1000000; - } - - break; - case NANOARROW_TIME_UNIT_MILLI: - overflow_safe = - val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; - if (overflow_safe) { - val *= 1000; - } - break; - case NANOARROW_TIME_UNIT_MICRO: - break; - case NANOARROW_TIME_UNIT_NANO: - val /= 1000; - break; - } - - if (!overflow_safe) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 - "' which exceeds PostgreSQL timestamp limits", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 "' which would underflow", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (bind_schema_fields[col].type 
== ArrowType::NANOARROW_TYPE_TIMESTAMP) { - const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - } else if (bind_schema_fields[col].type == - ArrowType::NANOARROW_TYPE_DURATION) { - // postgres stores an interval as a 64 bit offset in microsecond - // resolution alongside a 32 bit day and 32 bit month - // for now we just send 0 for the day / month values - const uint64_t value = ToNetworkInt64(val); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); - } - break; - } - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { - struct ArrowInterval interval; - ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); - ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); - - const uint32_t months = ToNetworkInt32(interval.months); - const uint32_t days = ToNetworkInt32(interval.days); - const uint64_t ms = ToNetworkInt64(interval.ns / 1000); - - std::memcpy(param_values[col], &ms, sizeof(uint64_t)); - std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); - std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), - &months, sizeof(uint32_t)); - break; - } - default: - SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", - bind_schema->children[col]->name, - "') has unsupported type for ingestion ", - ArrowTypeString(bind_schema_fields[col].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - } - - result = PQexecPrepared(conn, /*stmtName=*/"", - /*nParams=*/bind_schema->n_children, param_values.data(), - param_lengths.data(), param_formats.data(), - /*resultFormat=*/0 /*text*/); - - ExecStatusType pg_status = PQresultStatus(result); - if (pg_status != PGRES_COMMAND_OK) { - AdbcStatusCode code = SetError( - error, result, "[libpq] Failed to execute prepared statement: %s %s", - PQresStatus(pg_status), PQerrorMessage(conn)); - PQclear(result); - return code; - } - - PQclear(result); - } - if (rows_affected) *rows_affected += array->length; - - if (has_tz_field) { - std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; - PGresult* reset_tz_result = PQexec(conn, reset_query.c_str()); - if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", - PQerrorMessage(conn)); - PQclear(reset_tz_result); - return code; - } - PQclear(reset_tz_result); - - PGresult* commit_result = PQexec(conn, "COMMIT"); - if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, commit_result, "[libpq] Failed to commit transaction: %s", - PQerrorMessage(conn)); - PQclear(commit_result); - return code; - } - PQclear(commit_result); - } - } - return ADBC_STATUS_OK; - } - - AdbcStatusCode ExecuteCopy(PGconn* conn, int64_t* rows_affected, - struct AdbcError* error) { - if (rows_affected) *rows_affected = 0; - - PostgresCopyStreamWriter writer; - CHECK_NA(INTERNAL, writer.Init(&bind_schema.value), error); - CHECK_NA(INTERNAL, writer.InitFieldWriters(nullptr), error); - - CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error); - - while (true) { - Handle array; - int res = bind->get_next(&bind.value, &array.value); - if (res != 0) { - SetError(error, - "[libpq] Failed to read next batch from stream of bind parameters: " - "(%d) %s %s", - res, std::strerror(res), bind->get_last_error(&bind.value)); - return ADBC_STATUS_IO; - } - if 
(!array->release) break; - - CHECK_NA(INTERNAL, writer.SetArray(&array.value), error); - - // build writer buffer - int write_result; - do { - write_result = writer.WriteRecord(nullptr); - } while (write_result == NANOARROW_OK); - - // check if not ENODATA at exit - if (write_result != ENODATA) { - SetError(error, "Error occurred writing COPY data: %s", PQerrorMessage(conn)); - return ADBC_STATUS_IO; - } - - ArrowBuffer buffer = writer.WriteBuffer(); - if (PQputCopyData(conn, reinterpret_cast(buffer.data), buffer.size_bytes) <= - 0) { - SetError(error, "Error writing tuple field data: %s", PQerrorMessage(conn)); - return ADBC_STATUS_IO; - } - - if (rows_affected) *rows_affected += array->length; - writer.Rewind(); - } - - if (PQputCopyEnd(conn, NULL) <= 0) { - SetError(error, "Error message returned by PQputCopyEnd: %s", PQerrorMessage(conn)); - return ADBC_STATUS_IO; - } - - PGresult* result = PQgetResult(conn); - ExecStatusType pg_status = PQresultStatus(result); - if (pg_status != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to execute COPY statement: %s %s", - PQresStatus(pg_status), PQerrorMessage(conn)); - PQclear(result); - return code; - } - - PQclear(result); - return ADBC_STATUS_OK; - } -}; -} // namespace - int TupleReader::GetSchema(struct ArrowSchema* out) { assert(copy_reader_ != nullptr); + ArrowErrorInit(&na_error_); int na_res = copy_reader_->GetSchema(out); if (out->release == nullptr) { @@ -654,75 +67,74 @@ int TupleReader::GetSchema(struct ArrowSchema* out) { return na_res; } -int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) { - // Fetch + parse the header - int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); - data_.size_bytes = get_copy_res; - data_.data.as_char = pgbuf_; +int TupleReader::GetCopyData() { + if (pgbuf_ != nullptr) { + PQfreemem(pgbuf_); + pgbuf_ = nullptr; + } + data_.size_bytes = 0; + data_.data.as_char = nullptr; + + int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); if (get_copy_res == -2) { - SetError(&error_, "[libpq] Fetch header failed: %s", PQerrorMessage(conn_)); + SetError(&error_, "[libpq] PQgetCopyData() failed: %s", PQerrorMessage(conn_)); status_ = ADBC_STATUS_IO; return AdbcStatusCodeToErrno(status_); } - int na_res = copy_reader_->ReadHeader(&data_, error); - if (na_res != NANOARROW_OK) { - SetError(&error_, "[libpq] ReadHeader failed: %s", error->message); - status_ = ADBC_STATUS_IO; - return AdbcStatusCodeToErrno(status_); + if (get_copy_res == -1) { + // Check the server-side response + PQclear(result_); + result_ = PQgetResult(conn_); + const ExecStatusType pq_status = PQresultStatus(result_); + if (pq_status != PGRES_COMMAND_OK) { + status_ = SetError(&error_, result_, "[libpq] Execution error [%s]: %s", + PQresStatus(pq_status), PQresultErrorMessage(result_)); + return AdbcStatusCodeToErrno(status_); + } else { + return ENODATA; + } } + data_.size_bytes = get_copy_res; + data_.data.as_char = pgbuf_; return NANOARROW_OK; } -int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) { +int TupleReader::AppendRowAndFetchNext() { // Parse the result (the header AND the first row are included in the first // call to PQgetCopyData()) - int na_res = copy_reader_->ReadRecord(&data_, error); + int na_res = copy_reader_->ReadRecord(&data_, &na_error_); if (na_res != NANOARROW_OK && na_res != ENODATA) { SetError(&error_, "[libpq] ReadRecord failed at row %" PRId64 ": %s", row_id_, - error->message); + na_error_.message); status_ = ADBC_STATUS_IO; return 
na_res; } row_id_++; - // Fetch + check - PQfreemem(pgbuf_); - pgbuf_ = nullptr; - int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); - data_.size_bytes = get_copy_res; - data_.data.as_char = pgbuf_; - - if (get_copy_res == -2) { - SetError(&error_, "[libpq] PQgetCopyData failed at row %" PRId64 ": %s", row_id_, - PQerrorMessage(conn_)); - status_ = ADBC_STATUS_IO; - return AdbcStatusCodeToErrno(status_); - } else if (get_copy_res == -1) { - // Returned when COPY has finished successfully - return ENODATA; - } else if ((copy_reader_->array_size_approx_bytes() + get_copy_res) >= - batch_size_hint_bytes_) { + NANOARROW_RETURN_NOT_OK(GetCopyData()); + if ((copy_reader_->array_size_approx_bytes() + data_.size_bytes) >= + batch_size_hint_bytes_) { // Appending the next row will result in an array larger than requested. // Return EOVERFLOW to force GetNext() to build the current result and return. return EOVERFLOW; - } else { - return NANOARROW_OK; } + + return NANOARROW_OK; } -int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) { +int TupleReader::BuildOutput(struct ArrowArray* out) { if (copy_reader_->array_size_approx_bytes() == 0) { out->release = nullptr; return NANOARROW_OK; } - int na_res = copy_reader_->GetArray(out, error); + int na_res = copy_reader_->GetArray(out, &na_error_); if (na_res != NANOARROW_OK) { - SetError(&error_, "[libpq] Failed to build result array: %s", error->message); + SetError(&error_, "[libpq] Failed to build result array: %s", na_error_.message); status_ = ADBC_STATUS_INTERNAL; return na_res; } @@ -736,22 +148,35 @@ int TupleReader::GetNext(struct ArrowArray* out) { return 0; } - struct ArrowError error; - error.message[0] = '\0'; + int na_res; + ArrowErrorInit(&na_error_); if (row_id_ == -1) { - NANOARROW_RETURN_NOT_OK(InitQueryAndFetchFirst(&error)); + na_res = GetCopyData(); + if (na_res == ENODATA) { + is_finished_ = true; + out->release = nullptr; + return 0; + } else if (na_res != NANOARROW_OK) { + return na_res; + } + + na_res = copy_reader_->ReadHeader(&data_, &na_error_); + if (na_res != NANOARROW_OK) { + SetError(&error_, "[libpq] ReadHeader() failed: %s", na_error_.message); + return na_res; + } + row_id_++; } - int na_res; do { - na_res = AppendRowAndFetchNext(&error); + na_res = AppendRowAndFetchNext(); if (na_res == EOVERFLOW) { // The result would be too big to return if we appended the row. When EOVERFLOW is // returned, the copy reader leaves the output in a valid state. The data is left in // pg_buf_/data_ and will attempt to be appended on the next call to GetNext() - return BuildOutput(out, &error); + return BuildOutput(out); } } while (na_res == NANOARROW_OK); @@ -764,31 +189,7 @@ int TupleReader::GetNext(struct ArrowArray* out) { // Finish the result properly and return the last result. Note that BuildOutput() may // set tmp.release = nullptr if there were zero rows in the copy reader (can // occur in an overflow scenario). 
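// Illustrative sketch (not from this diff): the EOVERFLOW handling above is what bounds
// the size of each emitted batch. Once the accumulated array reaches
// batch_size_hint_bytes_ (16777216 bytes by default, per the TupleReader constructor),
// the current result is built and returned. Callers can tune that threshold through the
// statement option declared in statement.h; the helper name below is hypothetical.

static AdbcStatusCode RequestSmallerBatches(struct AdbcStatement* stmt,
                                            struct AdbcError* error) {
  // Ask the COPY reader for roughly 1 MiB batches instead of the 16 MiB default.
  return AdbcStatementSetOption(stmt, "adbc.postgresql.batch_size_hint_bytes", "1048576",
                                error);
}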
- struct ArrowArray tmp; - NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error)); - - PQclear(result_); - // Check the server-side response - result_ = PQgetResult(conn_); - const ExecStatusType pq_status = PQresultStatus(result_); - if (pq_status != PGRES_COMMAND_OK) { - const char* sqlstate = PQresultErrorField(result_, PG_DIAG_SQLSTATE); - SetError(&error_, result_, "[libpq] Query failed [%s]: %s", PQresStatus(pq_status), - PQresultErrorMessage(result_)); - - if (tmp.release != nullptr) { - tmp.release(&tmp); - } - - if (sqlstate != nullptr && std::strcmp(sqlstate, "57014") == 0) { - status_ = ADBC_STATUS_CANCELLED; - } else { - status_ = ADBC_STATUS_IO; - } - return AdbcStatusCodeToErrno(status_); - } - - ArrowArrayMove(&tmp, out); + NANOARROW_RETURN_NOT_OK(BuildOutput(out)); return NANOARROW_OK; } @@ -896,13 +297,7 @@ AdbcStatusCode PostgresStatement::Bind(struct ArrowArray* values, if (bind_.release) bind_.release(&bind_); // Make a one-value stream - bind_.private_data = new OneValueStream{*schema, *values}; - bind_.get_schema = &OneValueStream::GetSchema; - bind_.get_next = &OneValueStream::GetNext; - bind_.get_last_error = &OneValueStream::GetLastError; - bind_.release = &OneValueStream::Release; - std::memset(values, 0, sizeof(*values)); - std::memset(schema, 0, sizeof(*schema)); + adbc::driver::MakeArrayStream(schema, values, &bind_); return ADBC_STATUS_OK; } @@ -924,11 +319,11 @@ AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) { return connection_->Cancel(error); } -AdbcStatusCode PostgresStatement::CreateBulkTable( - const std::string& current_schema, const struct ArrowSchema& source_schema, - const std::vector& source_schema_fields, - std::string* escaped_table, std::string* escaped_field_list, - struct AdbcError* error) { +AdbcStatusCode PostgresStatement::CreateBulkTable(const std::string& current_schema, + const struct ArrowSchema& source_schema, + std::string* escaped_table, + std::string* escaped_field_list, + struct AdbcError* error) { PGconn* conn = connection_->conn(); if (!ingest_.db_schema.empty() && ingest_.temporary) { @@ -1011,7 +406,7 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( create += *escaped_table; create += " ("; - for (size_t i = 0; i < source_schema_fields.size(); i++) { + for (int64_t i = 0; i < source_schema.n_children; i++) { if (i > 0) { create += ", "; *escaped_field_list += ", "; @@ -1028,82 +423,13 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( *escaped_field_list += escaped; PQfreemem(escaped); - switch (source_schema_fields[i].type) { - case ArrowType::NANOARROW_TYPE_BOOL: - create += " BOOLEAN"; - break; - case ArrowType::NANOARROW_TYPE_INT8: - case ArrowType::NANOARROW_TYPE_INT16: - create += " SMALLINT"; - break; - case ArrowType::NANOARROW_TYPE_INT32: - create += " INTEGER"; - break; - case ArrowType::NANOARROW_TYPE_INT64: - create += " BIGINT"; - break; - case ArrowType::NANOARROW_TYPE_FLOAT: - create += " REAL"; - break; - case ArrowType::NANOARROW_TYPE_DOUBLE: - create += " DOUBLE PRECISION"; - break; - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - create += " TEXT"; - break; - case ArrowType::NANOARROW_TYPE_BINARY: - create += " BYTEA"; - break; - case ArrowType::NANOARROW_TYPE_DATE32: - create += " DATE"; - break; - case ArrowType::NANOARROW_TYPE_TIMESTAMP: - if (strcmp("", source_schema_fields[i].timezone)) { - create += " TIMESTAMPTZ"; - } else { - create += " TIMESTAMP"; - } - break; - case ArrowType::NANOARROW_TYPE_DURATION: - case 
ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - create += " INTERVAL"; - break; - case ArrowType::NANOARROW_TYPE_DECIMAL128: - case ArrowType::NANOARROW_TYPE_DECIMAL256: - create += " DECIMAL"; - break; - case ArrowType::NANOARROW_TYPE_DICTIONARY: { - struct ArrowSchemaView value_view; - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&value_view, source_schema.children[i]->dictionary, - nullptr), - error); - switch (value_view.type) { - case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_LARGE_BINARY: - create += " BYTEA"; - break; - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: - create += " TEXT"; - break; - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", source_schema.children[i]->name, - "') has unsupported dictionary value type for ingestion ", - ArrowTypeString(value_view.type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - break; - } - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", source_schema.children[i]->name, - "') has unsupported type for ingestion ", - ArrowTypeString(source_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } + PostgresType pg_type; + struct ArrowError na_error; + CHECK_NA_DETAIL(INTERNAL, + PostgresType::FromSchema(*type_resolver_, source_schema.children[i], + &pg_type, &na_error), + &na_error, error); + create += " " + pg_type.sql_type_name(); } if (ingest_.mode == IngestMode::kAppend) { @@ -1127,29 +453,14 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( return ADBC_STATUS_OK; } -AdbcStatusCode PostgresStatement::ExecutePreparedStatement( - struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error) { - if (!bind_.release) { - // TODO: set an empty stream just to unify the code paths - SetError(error, "%s", - "[libpq] Prepared statements without parameters are not implemented"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - if (stream) { - // TODO: - SetError(error, "%s", - "[libpq] Prepared statements returning result sets are not implemented"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - BindStream bind_stream(std::move(bind_)); - std::memset(&bind_, 0, sizeof(bind_)); - - RAISE_ADBC(bind_stream.Begin([&]() { return ADBC_STATUS_OK; }, error)); - RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); - RAISE_ADBC( - bind_stream.Prepare(connection_->conn(), query_, error, connection_->autocommit())); - RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error)); +AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream, + int64_t* rows_affected, + struct AdbcError* error) { + PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetAutocommit(connection_->autocommit()); + reader.SetBind(&bind_); + reader.SetVendorName(connection_->VendorName()); + RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); return ADBC_STATUS_OK; } @@ -1157,27 +468,10 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error) { ClearResult(); - if (prepared_) { - if (bind_.release || !stream) { - return ExecutePreparedStatement(stream, rows_affected, error); - } - // XXX: don't use a prepared statement to execute a no-parameter - // result-set-returning query for now, since we can't easily get - // access to COPY there. 
(This might have to become sequential - // executions of COPY (EXECUTE ($n, ...)) TO STDOUT which won't - // get all the benefits of a prepared statement.) At preparation - // time we don't know whether the query will be used with a result - // set or not without analyzing the query (we could prepare both?) - // and https://stackoverflow.com/questions/69233792 suggests that - // you can't PREPARE a query containing COPY. - } - if (!stream && !ingest_.target.empty()) { - return ExecuteUpdateBulk(rows_affected, error); - } - // Remove trailing semicolon(s) from the query before feeding it into COPY - while (!query_.empty() && query_.back() == ';') { - query_.pop_back(); + // Use a dedicated path to handle bulk ingest + if (!ingest_.target.empty()) { + return ExecuteIngest(stream, rows_affected, error); } if (query_.empty()) { @@ -1185,53 +479,57 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, return ADBC_STATUS_INVALID_STATE; } - // 1. Prepare the query to get the schema - { - RAISE_ADBC(SetupReader(error)); - - // If the caller did not request a result set or if there are no - // inferred output columns (e.g. a CREATE or UPDATE), then don't - // use COPY (which would fail anyways) - if (!stream || reader_.copy_reader_->pg_type().n_children() == 0) { - RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error)); - if (stream) { - struct ArrowSchema schema; - std::memset(&schema, 0, sizeof(schema)); - RAISE_NA(reader_.copy_reader_->GetSchema(&schema)); - nanoarrow::EmptyArrayStream::MakeUnique(&schema).move(stream); - } - return ADBC_STATUS_OK; - } + // Use a dedicated path to handle parameter binding + if (bind_.release != nullptr) { + return ExecuteBind(stream, rows_affected, error); + } - // This resolves the reader specific to each PostgresType -> ArrowSchema - // conversion. It is unlikely that this will fail given that we have just - // inferred these conversions ourselves. - struct ArrowError na_error; - int na_res = reader_.copy_reader_->InitFieldReaders(&na_error); - if (na_res != NANOARROW_OK) { - SetError(error, "[libpq] Failed to initialize field readers: %s", na_error.message); - return na_res; - } + // If we have been requested to avoid COPY or there is no output requested, + // execute using the PqResultArrayReader. + if (!stream || !UseCopy()) { + PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetVendorName(connection_->VendorName()); + RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); + return ADBC_STATUS_OK; } - // 2. 
Execute the query with COPY to get binary tuples - { - std::string copy_query = "COPY (" + query_ + ") TO STDOUT (FORMAT binary)"; - reader_.result_ = - PQexecParams(connection_->conn(), copy_query.c_str(), /*nParams=*/0, - /*paramTypes=*/nullptr, /*paramValues=*/nullptr, - /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat); - if (PQresultStatus(reader_.result_) != PGRES_COPY_OUT) { - AdbcStatusCode code = SetError( - error, reader_.result_, - "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s", - PQerrorMessage(connection_->conn()), copy_query.c_str()); - ClearResult(); - return code; - } - // Result is read from the connection, not the result, but we won't clear it here + PqResultHelper helper(connection_->conn(), query_); + RAISE_STATUS(error, helper.Prepare()); + RAISE_STATUS(error, helper.DescribePrepared()); + + // Initialize the copy reader and infer the output schema (i.e., error for + // unsupported types before issuing the COPY query). This could be lazier + // (i.e., executed on the first call to GetSchema() or GetNext()). + PostgresType root_type; + RAISE_STATUS(error, helper.ResolveOutputTypes(*type_resolver_, &root_type)); + + // If there will be no columns in the result, we can also avoid COPY + if (root_type.n_children() == 0) { + // Could/should move the helper into the reader instead of repreparing + PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetVendorName(connection_->VendorName()); + RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); + return ADBC_STATUS_OK; } + struct ArrowError na_error; + reader_.copy_reader_ = std::make_unique(); + CHECK_NA(INTERNAL, reader_.copy_reader_->Init(root_type), error); + CHECK_NA_DETAIL(INTERNAL, + reader_.copy_reader_->InferOutputSchema( + std::string(connection_->VendorName()), &na_error), + &na_error, error); + + CHECK_NA_DETAIL(INTERNAL, reader_.copy_reader_->InitFieldReaders(&na_error), &na_error, + error); + + // Execute the COPY query + RAISE_STATUS(error, helper.ExecuteCopy()); + + // We need the PQresult back for the reader + reader_.result_ = helper.ReleaseResult(); + + // Export to stream reader_.ExportTo(stream); if (rows_affected) *rows_affected = -1; return ADBC_STATUS_OK; @@ -1243,51 +541,93 @@ AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema, if (query_.empty()) { SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery"); return ADBC_STATUS_INVALID_STATE; - } else if (bind_.release) { - // TODO: if we have parameters, bind them (since they can affect the output schema) - SetError(error, "[libpq] ExecuteSchema with parameters is not implemented"); - return ADBC_STATUS_NOT_IMPLEMENTED; } - RAISE_ADBC(SetupReader(error)); - CHECK_NA(INTERNAL, reader_.copy_reader_->GetSchema(schema), error); + PqResultHelper helper(connection_->conn(), query_); + + if (bind_.release) { + nanoarrow::UniqueSchema param_schema; + struct ArrowError na_error; + ArrowErrorInit(&na_error); + CHECK_NA_DETAIL(INTERNAL, + ArrowArrayStreamGetSchema(&bind_, param_schema.get(), &na_error), + &na_error, error); + + if (std::string(param_schema->format) != "+s") { + SetError(error, "%s", "[libpq] Bind parameters must have type STRUCT"); + return ADBC_STATUS_INVALID_STATE; + } + + std::vector param_oids(param_schema->n_children); + for (int64_t i = 0; i < param_schema->n_children; i++) { + PostgresType pg_type; + CHECK_NA_DETAIL(INTERNAL, + PostgresType::FromSchema(*type_resolver_, param_schema->children[i], + &pg_type, 
&na_error), + &na_error, error); + param_oids[i] = pg_type.oid(); + } + + RAISE_STATUS(error, helper.Prepare(param_oids)); + } else { + RAISE_STATUS(error, helper.Prepare()); + } + + RAISE_STATUS(error, helper.DescribePrepared()); + + PostgresType output_type; + RAISE_STATUS(error, helper.ResolveOutputTypes(*type_resolver_, &output_type)); + + nanoarrow::UniqueSchema tmp; + ArrowSchemaInit(tmp.get()); + CHECK_NA(INTERNAL, + output_type.SetSchema(tmp.get(), std::string(connection_->VendorName())), + error); + + tmp.move(schema); return ADBC_STATUS_OK; } -AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, - struct AdbcError* error) { +AdbcStatusCode PostgresStatement::ExecuteIngest(struct ArrowArrayStream* stream, + int64_t* rows_affected, + struct AdbcError* error) { if (!bind_.release) { SetError(error, "%s", "[libpq] Must Bind() before Execute() for bulk ingestion"); return ADBC_STATUS_INVALID_STATE; } + if (stream != nullptr) { + SetError(error, "%s", "[libpq] Bulk ingest with result set is not supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + // Need the current schema to avoid being shadowed by temp tables // This is a little unfortunate; we need another DB roundtrip std::string current_schema; { - PqResultHelper result_helper{connection_->conn(), "SELECT CURRENT_SCHEMA", {}, error}; - RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + PqResultHelper result_helper{connection_->conn(), "SELECT CURRENT_SCHEMA()"}; + RAISE_STATUS(error, result_helper.Execute()); auto it = result_helper.begin(); if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + SetError(error, + "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA()'"); return ADBC_STATUS_INTERNAL; } current_schema = (*it)[0].data; } - BindStream bind_stream(std::move(bind_)); + BindStream bind_stream; + bind_stream.SetBind(&bind_); std::memset(&bind_, 0, sizeof(bind_)); std::string escaped_table; std::string escaped_field_list; - RAISE_ADBC(bind_stream.Begin( - [&]() -> AdbcStatusCode { - return CreateBulkTable(current_schema, bind_stream.bind_schema.value, - bind_stream.bind_schema_fields, &escaped_table, - &escaped_field_list, error); - }, - error)); - RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); + RAISE_STATUS(error, bind_stream.Begin([&]() -> Status { + struct AdbcError tmp_error = ADBC_ERROR_INIT; + AdbcStatusCode status_code = + CreateBulkTable(current_schema, bind_stream.bind_schema.value, &escaped_table, + &escaped_field_list, &tmp_error); + return Status::FromAdbc(status_code, tmp_error); + })); std::string query = "COPY "; query += escaped_table; @@ -1304,38 +644,9 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, } PQclear(result); - RAISE_ADBC(bind_stream.ExecuteCopy(connection_->conn(), rows_affected, error)); - return ADBC_STATUS_OK; -} - -AdbcStatusCode PostgresStatement::ExecuteUpdateQuery(int64_t* rows_affected, - struct AdbcError* error) { - // NOTE: must prepare first (used in ExecuteQuery) - PGresult* result = - PQexecPrepared(connection_->conn(), /*stmtName=*/"", /*nParams=*/0, - /*paramValues=*/nullptr, /*paramLengths=*/nullptr, - /*paramFormats=*/nullptr, /*resultFormat=*/kPgBinaryFormat); - ExecStatusType status = PQresultStatus(result); - if (status != PGRES_COMMAND_OK && status != PGRES_TUPLES_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to execute query: %s\nQuery was:%s", - 
PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return code; - } - if (rows_affected) { - if (status == PGRES_TUPLES_OK) { - *rows_affected = PQntuples(reader_.result_); - } else { - // In theory, PQcmdTuples would work here, but experimentally it gives - // an empty string even for a DELETE. (Also, why does it return a - // string...) Possibly, it doesn't work because we use PQexecPrepared - // but the docstring is careful to specify it works on an EXECUTE of a - // prepared statement. - *rows_affected = -1; - } - } - PQclear(result); + RAISE_STATUS(error, + bind_stream.ExecuteCopy(connection_->conn(), *connection_->type_resolver(), + rows_affected)); return ADBC_STATUS_OK; } @@ -1363,6 +674,12 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t } } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { result = std::to_string(reader_.batch_size_hint_bytes_); + } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) { + if (UseCopy()) { + result = "true"; + } else { + result = "false"; + } } else { SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_FOUND; @@ -1482,6 +799,15 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, } this->reader_.batch_size_hint_bytes_ = int_value; + } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) { + if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { + use_copy_ = true; + } else if (std::strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { + use_copy_ = false; + } else { + SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); + return ADBC_STATUS_INVALID_ARGUMENT; + } } else { SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_IMPLEMENTED; @@ -1516,53 +842,17 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value, return ADBC_STATUS_NOT_IMPLEMENTED; } -AdbcStatusCode PostgresStatement::SetupReader(struct AdbcError* error) { - // TODO: we should pipeline here and assume this will succeed - PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), - /*nParams=*/0, nullptr); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, - "[libpq] Failed to execute query: could not infer schema: failed to " - "prepare query: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return code; - } - PQclear(result); - result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, - "[libpq] Failed to execute query: could not infer schema: failed to " - "describe prepared statement: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return code; - } - - // Resolve the information from the PGresult into a PostgresType - PostgresType root_type; - AdbcStatusCode status = ResolvePostgresType(*type_resolver_, result, &root_type, error); - PQclear(result); - if (status != ADBC_STATUS_OK) return status; - - // Initialize the copy reader and infer the output schema (i.e., error for - // unsupported types before issuing the COPY query) - reader_.copy_reader_ = std::make_unique(); - reader_.copy_reader_->Init(root_type); - struct ArrowError na_error; - int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); - if (na_res != NANOARROW_OK) { - 
SetError(error, "[libpq] Failed to infer output schema: (%d) %s: %s", na_res, - std::strerror(na_res), na_error.message); - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - void PostgresStatement::ClearResult() { // TODO: we may want to synchronize here for safety reader_.Release(); } + +int PostgresStatement::UseCopy() { + if (use_copy_ == -1) { + return connection_->VendorName() != "Redshift"; + } else { + return use_copy_; + } +} + } // namespace adbcpq diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h index 33be8f0c3d..60ada992b0 100644 --- a/c/driver/postgresql/statement.h +++ b/c/driver/postgresql/statement.h @@ -22,7 +22,7 @@ #include #include -#include +#include #include #include @@ -33,6 +33,8 @@ #define ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES \ "adbc.postgresql.batch_size_hint_bytes" +#define ADBC_POSTGRESQL_OPTION_USE_COPY "adbc.postgresql.use_copy" + namespace adbcpq { class PostgresConnection; class PostgresStatement; @@ -50,6 +52,7 @@ class TupleReader final { row_id_(-1), batch_size_hint_bytes_(16777216), is_finished_(false) { + ArrowErrorInit(&na_error_); data_.data.as_char = nullptr; data_.size_bytes = 0; } @@ -66,9 +69,9 @@ class TupleReader final { private: friend class PostgresStatement; - int InitQueryAndFetchFirst(struct ArrowError* error); - int AppendRowAndFetchNext(struct ArrowError* error); - int BuildOutput(struct ArrowArray* out, struct ArrowError* error); + int GetCopyData(); + int AppendRowAndFetchNext(); + int BuildOutput(struct ArrowArray* out); static int GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out); static int GetNextTrampoline(struct ArrowArrayStream* self, struct ArrowArray* out); @@ -77,6 +80,7 @@ class TupleReader final { AdbcStatusCode status_; struct AdbcError error_; + struct ArrowError na_error_; PGconn* conn_; PGresult* result_; char* pgbuf_; @@ -90,7 +94,11 @@ class TupleReader final { class PostgresStatement { public: PostgresStatement() - : connection_(nullptr), query_(), prepared_(false), reader_(nullptr) { + : connection_(nullptr), + query_(), + prepared_(false), + use_copy_(-1), + reader_(nullptr) { std::memset(&bind_, 0, sizeof(bind_)); } @@ -125,17 +133,15 @@ class PostgresStatement { // Helper methods void ClearResult(); - AdbcStatusCode CreateBulkTable( - const std::string& current_schema, const struct ArrowSchema& source_schema, - const std::vector& source_schema_fields, - std::string* escaped_table, std::string* escaped_field_list, - struct AdbcError* error); - AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error); - AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError* error); - AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream, - int64_t* rows_affected, - struct AdbcError* error); - AdbcStatusCode SetupReader(struct AdbcError* error); + AdbcStatusCode CreateBulkTable(const std::string& current_schema, + const struct ArrowSchema& source_schema, + std::string* escaped_table, + std::string* escaped_field_list, + struct AdbcError* error); + AdbcStatusCode ExecuteIngest(struct ArrowArrayStream* stream, int64_t* rows_affected, + struct AdbcError* error); + AdbcStatusCode ExecuteBind(struct ArrowArrayStream* stream, int64_t* rows_affected, + struct AdbcError* error); private: std::shared_ptr type_resolver_; @@ -154,6 +160,9 @@ class PostgresStatement { kCreateAppend, }; + // Options + int use_copy_; + struct { std::string db_schema; std::string target; @@ -162,5 +171,7 @@ class 
PostgresStatement { } ingest_; TupleReader reader_; + + int UseCopy(); }; } // namespace adbcpq diff --git a/c/driver/snowflake/CMakeLists.txt b/c/driver/snowflake/CMakeLists.txt index 3f05dfa042..1d3874b41f 100644 --- a/c/driver/snowflake/CMakeLists.txt +++ b/c/driver/snowflake/CMakeLists.txt @@ -35,7 +35,8 @@ add_go_lib("${REPOSITORY_ROOT}/go/adbc/pkg/snowflake/" foreach(LIB_TARGET ${ADBC_LIBRARIES}) target_include_directories(${LIB_TARGET} SYSTEM - INTERFACE ${REPOSITORY_ROOT} ${REPOSITORY_ROOT}/c/ + INTERFACE ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) endforeach() @@ -61,8 +62,8 @@ if(ADBC_BUILD_TESTS) ${TEST_LINK_LIBS}) target_compile_features(adbc-driver-snowflake-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-snowflake-test SYSTEM - PRIVATE ${REPOSITORY_ROOT} - ${REPOSITORY_ROOT}/c/ + PRIVATE ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver ${REPOSITORY_ROOT}/c/driver/common) diff --git a/c/driver/snowflake/meson.build b/c/driver/snowflake/meson.build new file mode 100644 index 0000000000..20a7d3c70e --- /dev/null +++ b/c/driver/snowflake/meson.build @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
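[Editor's note] The meson rules below build the Snowflake driver as a Go c-shared library, which is then consumed like any other ADBC driver through the driver manager. A rough usage sketch, assuming the built libadbc_driver_snowflake is discoverable by the dynamic loader and that the URI value is a placeholder:

// Sketch only: open the Go-built Snowflake driver via the ADBC driver manager.
#include <arrow-adbc/adbc.h>

AdbcStatusCode OpenSnowflake(struct AdbcDatabase* database, struct AdbcError* error) {
  AdbcStatusCode status = AdbcDatabaseNew(database, error);
  if (status != ADBC_STATUS_OK) return status;
  // The driver manager resolves "adbc_driver_snowflake" to the shared library.
  status = AdbcDatabaseSetOption(database, "driver", "adbc_driver_snowflake", error);
  if (status != ADBC_STATUS_OK) return status;
  // Placeholder connection string; use your Snowflake account's DSN here.
  status = AdbcDatabaseSetOption(database, "uri",
                                 "user:password@account/database", error);
  if (status != ADBC_STATUS_OK) return status;
  return AdbcDatabaseInit(database, error);
}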
+ + +golang = find_program('go') + +if build_machine.system() == 'windows' + prefix = '' + suffix = '.lib' +elif build_machine.system() == 'darwin' + prefix = 'lib' + suffix = '.dylib' +else + prefix = 'lib' + suffix = '.so' +endif + +adbc_driver_snowflake_name = prefix + 'adbc_driver_snowflake' + suffix +adbc_driver_snowflake_lib = custom_target( + 'adbc_driver_snowflake', + output: adbc_driver_snowflake_name, + command : [ + golang, + 'build', + '-C', + meson.project_source_root() + '/../go/adbc/pkg/snowflake', + '-tags=driverlib', + '-buildmode=c-shared', + '-o', + meson.current_build_dir() + '/' + adbc_driver_snowflake_name, + ], + install : true, + install_dir : '.', +) + +pkg.generate( + name: 'Apache Arrow Database Connectivity (ADBC) Snowflake driver', + description: 'The ADBC Snowflake driver provides an ADBC driver for Snowflake.', + url: 'https://github.com/apache/arrow-adbc', + libraries: [adbc_driver_snowflake_lib], + filebase: 'adbc-driver-snowflake', +) + +if get_option('tests') + exc = executable( + 'adbc-driver-snowflake-test', + 'snowflake_test.cc', + include_directories: [include_dir, c_dir, driver_dir], + link_with: [ + adbc_common_lib, + adbc_driver_snowflake_lib, + ], + dependencies: [adbc_validation_dep], + ) + test('adbc-driver-snowflake', exc) +endif diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index a4d742491a..262286192a 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include +#include #include #include #include @@ -99,7 +99,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { adbc_validation::Handle statement; CHECK_OK(AdbcStatementNew(connection, &statement.value, error)); - std::string create = "CREATE TABLE \""; + std::string create = "CREATE OR REPLACE TABLE \""; create += name; create += "\" (int64s INT, strings TEXT)"; CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error)); @@ -131,7 +131,13 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { return NANOARROW_TYPE_DOUBLE; case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_LIST: + case NANOARROW_TYPE_LARGE_LIST: return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + return NANOARROW_TYPE_BINARY; default: return ingest_type; } @@ -149,7 +155,11 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { bool supports_dynamic_parameter_binding() const override { return true; } bool supports_error_on_incompatible_schema() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } + bool supports_ingest_view_types() const override { return false; } + bool supports_ingest_float16() const override { return false; } + std::string db_schema() const override { return schema_; } + std::string catalog() const override { return "ADBC_TESTING"; } const char* uri_; bool skip_{false}; diff --git a/c/driver/sqlite/CMakeLists.txt b/c/driver/sqlite/CMakeLists.txt index 3cfdd32bbf..d0c45b7433 100644 --- a/c/driver/sqlite/CMakeLists.txt +++ b/c/driver/sqlite/CMakeLists.txt @@ -64,8 +64,8 @@ foreach(LIB_TARGET ${ADBC_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_EXPORTING ${ADBC_SQLITE_COMPILE_DEFINES}) target_include_directories(${LIB_TARGET} SYSTEM - PRIVATE ${REPOSITORY_ROOT} - ${REPOSITORY_ROOT}/c/ + PRIVATE 
${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ ${SQLite3_INCLUDE_DIRS} ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) @@ -98,8 +98,8 @@ if(ADBC_BUILD_TESTS) PRIVATE ${ADBC_SQLITE_COMPILE_DEFINES}) target_compile_features(adbc-driver-sqlite-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-sqlite-test SYSTEM - PRIVATE ${REPOSITORY_ROOT} - ${REPOSITORY_ROOT}/c/ + PRIVATE ${REPOSITORY_ROOT}/c/ + ${REPOSITORY_ROOT}/c/include/ ${LIBPQ_INCLUDE_DIRS} ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) diff --git a/c/driver/sqlite/adbc-driver-sqlite.pc.in b/c/driver/sqlite/adbc-driver-sqlite.pc.in index 359e37b50c..157344e8fe 100644 --- a/c/driver/sqlite/adbc-driver-sqlite.pc.in +++ b/c/driver/sqlite/adbc-driver-sqlite.pc.in @@ -21,6 +21,7 @@ libdir=@ADBC_PKG_CONFIG_LIBDIR@ Name: Apache Arrow Database Connectivity (ADBC) SQLite driver Description: The ADBC SQLite driver provides an ADBC driver for SQLite. +URL: https://github.com/apache/arrow-adbc Version: @ADBC_VERSION@ Libs: -L${libdir} -ladbc_driver_sqlite Cflags: -I${includedir} diff --git a/c/driver/sqlite/meson.build b/c/driver/sqlite/meson.build new file mode 100644 index 0000000000..ad61f7e435 --- /dev/null +++ b/c/driver/sqlite/meson.build @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +sqlite3_dep = dependency('sqlite3') + +time_t_size = meson.get_compiler('c').sizeof( + 'time_t', + prefix : '#include ', +) + +adbc_sqlite3_driver_lib = library( + 'adbc_driver_sqlite', + sources: [ + 'sqlite.cc', + 'statement_reader.c', + ], + include_directories: [include_dir, c_dir], + link_with: [adbc_common_lib, adbc_framework_lib], + dependencies: [nanoarrow_dep, fmt_dep, sqlite3_dep], + c_args: ['-DSIZEOF_TIME_T=' + time_t_size.to_string()], +) + +pkg.generate( + name: 'Apache Arrow Database Connectivity (ADBC) SQLite driver', + description: 'The ADBC SQLite driver provides an ADBC driver for SQLite.', + url: 'https://github.com/apache/arrow-adbc', + libraries: [adbc_sqlite3_driver_lib], + filebase: 'adbc-driver-sqlite', +) + +if get_option('tests') + exc = executable( + 'adbc-driver-sqlite-test', + sources: ['sqlite_test.cc'], + include_directories: [include_dir, c_dir, driver_dir], + link_with: [ + adbc_common_lib, + adbc_sqlite3_driver_lib, + ], + dependencies: [sqlite3_dep, adbc_validation_dep], + ) + test('adbc-driver-sqlite', exc) +endif diff --git a/c/driver/sqlite/sqlite.cc b/c/driver/sqlite/sqlite.cc index f7ed45ae3a..6348b5ce31 100644 --- a/c/driver/sqlite/sqlite.cc +++ b/c/driver/sqlite/sqlite.cc @@ -15,21 +15,17 @@ // specific language governing permissions and limitations // under the License. 
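[Editor's note] The sqlite.cc hunks below define ADBC_FRAMEWORK_USE_FMT and move every formatted constructor from status::Internal(...) to status::fmt::Internal(...), keeping the fmt dependency confined to a nested namespace. This is a rough sketch of what such a wrapper can look like; it is not the framework's actual implementation, and the Status type and names here are assumptions:

// Sketch only: a formatted-status helper in the spirit of status::fmt::Internal.
#include <string>
#include <utility>

#include <fmt/format.h>

#include <arrow-adbc/adbc.h>

struct Status {
  AdbcStatusCode code = ADBC_STATUS_OK;
  std::string message;
};

namespace status {
// Unformatted overload: takes an already-built message.
inline Status Internal(std::string message) {
  return Status{ADBC_STATUS_INTERNAL, std::move(message)};
}
namespace fmt {
// Formatted overload lives in a nested namespace so the fmt dependency stays
// opt-in for code that defines ADBC_FRAMEWORK_USE_FMT.
template <typename... Args>
Status Internal(::fmt::format_string<Args...> format, Args&&... args) {
  return Status{ADBC_STATUS_INTERNAL,
                ::fmt::format(format, std::forward<Args>(args)...)};
}
}  // namespace fmt
}  // namespace status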
-#include #include -#include -#include +#include #include #include -#include "driver/common/options.h" -#include "driver/common/utils.h" -#include "driver/framework/base_connection.h" -#include "driver/framework/base_database.h" +#define ADBC_FRAMEWORK_USE_FMT #include "driver/framework/base_driver.h" -#include "driver/framework/base_statement.h" -#include "driver/framework/catalog.h" +#include "driver/framework/connection.h" +#include "driver/framework/database.h" +#include "driver/framework/statement.h" #include "driver/framework/status.h" #include "driver/sqlite/statement_reader.h" @@ -112,7 +108,7 @@ class SqliteStringBuilder { } else if (rc == SQLITE_TOOBIG) { return status::Internal("query too long"); } else if (rc != SQLITE_OK) { - return status::Internal("unknown SQLite error ({})", rc); + return status::fmt::Internal("unknown SQLite error ({})", rc); } len = sqlite3_str_length(str_); result_ = sqlite3_str_finish(str_); @@ -142,7 +138,7 @@ class SqliteQuery { Result Next() { if (!stmt_) { - return status::Internal( + return status::fmt::Internal( "query already finished or never initialized\nquery was: {}", query_); } int rc = sqlite3_step(stmt_); @@ -154,17 +150,17 @@ class SqliteQuery { return Close(rc); } - Status Close(int rc) { + Status Close(int last_rc) { if (stmt_) { int rc = sqlite3_finalize(stmt_); stmt_ = nullptr; if (rc != SQLITE_OK && rc != SQLITE_DONE) { - return status::Internal("failed to execute: {}\nquery was: {}", - sqlite3_errmsg(conn_), query_); + return status::fmt::Internal("failed to execute: {}\nquery was: {}", + sqlite3_errmsg(conn_), query_); } - } else if (rc != SQLITE_OK) { - return status::Internal("failed to execute: {}\nquery was: {}", - sqlite3_errmsg(conn_), query_); + } else if (last_rc != SQLITE_OK) { + return status::fmt::Internal("failed to execute: {}\nquery was: {}", + sqlite3_errmsg(conn_), query_); } return status::Ok(); } @@ -196,7 +192,7 @@ class SqliteQuery { UNWRAP_RESULT(bool has_row, q.Next()); if (!has_row) break; - int rc = std::forward(row_func)(q.stmt_); + rc = std::forward(row_func)(q.stmt_); if (rc != SQLITE_OK) break; } return q.Close(); @@ -222,10 +218,6 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { std::string query = "SELECT DISTINCT name FROM pragma_database_list() WHERE name LIKE ?"; - this->table_filter = table_filter; - this->column_filter = column_filter; - this->table_types = table_types; - UNWRAP_STATUS(SqliteQuery::Scan( conn, query, [&](sqlite3_stmt* stmt) { @@ -249,14 +241,17 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { return status::Ok(); } - Status LoadCatalogs() override { return status::Ok(); }; + Status LoadCatalogs(std::optional catalog_filter) override { + return status::Ok(); + }; Result> NextCatalog() override { if (next_catalog >= catalogs.size()) return std::nullopt; return catalogs[next_catalog++]; } - Status LoadSchemas(std::string_view catalog) override { + Status LoadSchemas(std::string_view catalog, + std::optional schema_filter) override { next_schema = 0; return status::Ok(); }; @@ -266,7 +261,9 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { return schemas[next_schema++]; } - Status LoadTables(std::string_view catalog, std::string_view schema) override { + Status LoadTables(std::string_view catalog, std::string_view schema, + std::optional table_filter, + const std::vector& table_types) override { next_table = 0; tables.clear(); if (!schema.empty()) return status::Ok(); @@ -309,7 +306,8 @@ struct SqliteGetObjectsHelper : public 
driver::GetObjectsHelper { } Status LoadColumns(std::string_view catalog, std::string_view schema, - std::string_view table) override { + std::string_view table, + std::optional column_filter) override { // XXX: pragma_table_info doesn't appear to work with bind parameters // XXX: because we're saving the SqliteQuery, we also need to save the string builder columns_query.Reset(); @@ -486,9 +484,6 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { }; sqlite3* conn = nullptr; - std::optional table_filter; - std::optional column_filter; - std::vector table_types; std::vector catalogs; std::vector schemas; std::vector> tables; @@ -501,7 +496,7 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { size_t next_constraint = 0; }; -class SqliteDatabase : public driver::DatabaseBase { +class SqliteDatabase : public driver::Database { public: [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[SQLite]"; @@ -513,9 +508,9 @@ class SqliteDatabase : public driver::DatabaseBase { if (rc != SQLITE_OK) { Status status; if (conn_) { - status = status::IO("failed to open '{}': {}", uri_, sqlite3_errmsg(conn)); + status = status::fmt::IO("failed to open '{}': {}", uri_, sqlite3_errmsg(conn)); } else { - status = status::IO("failed to open '{}': failed to allocate memory", uri_); + status = status::fmt::IO("failed to open '{}': failed to allocate memory", uri_); } (void)sqlite3_close(conn); return status; @@ -532,8 +527,8 @@ class SqliteDatabase : public driver::DatabaseBase { if (conn_) { int rc = sqlite3_close_v2(conn_); if (rc != SQLITE_OK) { - return status::IO("failed to close connection: ({}) {}", rc, - sqlite3_errmsg(conn_)); + return status::fmt::IO("failed to close connection: ({}) {}", rc, + sqlite3_errmsg(conn_)); } conn_ = nullptr; } @@ -557,7 +552,7 @@ class SqliteDatabase : public driver::DatabaseBase { sqlite3* conn_ = nullptr; }; -class SqliteConnection : public driver::ConnectionBase { +class SqliteConnection : public driver::Connection { public: [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[SQLite]"; @@ -594,7 +589,7 @@ class SqliteConnection : public driver::ConnectionBase { /*pzTail=*/nullptr); if (rc != SQLITE_OK) { (void)sqlite3_finalize(stmt); - return status::NotFound("GetTableSchema: {}", sqlite3_errmsg(conn_)); + return status::fmt::NotFound("GetTableSchema: {}", sqlite3_errmsg(conn_)); } nanoarrow::UniqueArrayStream stream; @@ -606,7 +601,8 @@ class SqliteConnection : public driver::ConnectionBase { int code = stream->get_schema(stream.get(), schema); if (code != 0) { (void)sqlite3_finalize(stmt); - return status::IO("failed to get schema: ({}) {}", code, std::strerror(code)); + return status::fmt::IO("failed to get schema: ({}) {}", code, + std::strerror(code)); } } (void)sqlite3_finalize(stmt); @@ -665,12 +661,12 @@ class SqliteConnection : public driver::ConnectionBase { if (conn_) { int rc = sqlite3_close_v2(conn_); if (rc != SQLITE_OK) { - return status::IO("failed to close connection: ({}) {}", rc, - sqlite3_errmsg(conn_)); + return status::fmt::IO("failed to close connection: ({}) {}", rc, + sqlite3_errmsg(conn_)); } conn_ = nullptr; } - return ConnectionBase::ReleaseImpl(); + return Connection::ReleaseImpl(); } Status RollbackImpl() { @@ -689,7 +685,8 @@ class SqliteConnection : public driver::ConnectionBase { int rc = sqlite3_db_config(conn_, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, enabled ? 
1 : 0, nullptr); if (rc != SQLITE_OK) { - return status::IO("cannot enable extension loading: {}", sqlite3_errmsg(conn_)); + return status::fmt::IO("cannot enable extension loading: {}", + sqlite3_errmsg(conn_)); } return status::Ok(); } else if (key == kConnectionOptionLoadExtensionPath) { @@ -702,9 +699,9 @@ class SqliteConnection : public driver::ConnectionBase { } else if (key == kConnectionOptionLoadExtensionEntrypoint) { #if !defined(ADBC_SQLITE_WITH_NO_LOAD_EXTENSION) if (extension_path_.empty()) { - return status::InvalidState("{} can only be set after {}", - kConnectionOptionLoadExtensionEntrypoint, - kConnectionOptionLoadExtensionPath); + return status::fmt::InvalidState("{} can only be set after {}", + kConnectionOptionLoadExtensionEntrypoint, + kConnectionOptionLoadExtensionPath); } const char* extension_entrypoint = nullptr; if (value.has_value()) { @@ -716,7 +713,7 @@ class SqliteConnection : public driver::ConnectionBase { int rc = sqlite3_load_extension(conn_, extension_path_.c_str(), extension_entrypoint, &message); if (rc != SQLITE_OK) { - auto status = status::Unknown( + auto status = status::fmt::Unknown( "failed to load extension {} (entrypoint {}): {}", extension_path_, extension_entrypoint ? extension_entrypoint : "(NULL)", message ? message : "(unknown error)"); @@ -757,7 +754,7 @@ class SqliteConnection : public driver::ConnectionBase { std::string extension_path_; }; -class SqliteStatement : public driver::StatementBase { +class SqliteStatement : public driver::Statement { public: [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[SQLite]"; @@ -782,15 +779,15 @@ class SqliteStatement : public driver::StatementBase { // Parameter validation if (state.target_catalog && state.temporary) { - return status::InvalidState("{} Cannot set both {} and {}", kErrorPrefix, - ADBC_INGEST_OPTION_TARGET_CATALOG, - ADBC_INGEST_OPTION_TEMPORARY); + return status::fmt::InvalidState("{} Cannot set both {} and {}", kErrorPrefix, + ADBC_INGEST_OPTION_TARGET_CATALOG, + ADBC_INGEST_OPTION_TEMPORARY); } else if (state.target_schema) { - return status::NotImplemented("{} {} not supported", kErrorPrefix, - ADBC_INGEST_OPTION_TARGET_DB_SCHEMA); + return status::fmt::NotImplemented("{} {} not supported", kErrorPrefix, + ADBC_INGEST_OPTION_TARGET_DB_SCHEMA); } else if (!state.target_table) { - return status::InvalidState("{} Must set {}", kErrorPrefix, - ADBC_INGEST_OPTION_TARGET_TABLE); + return status::fmt::InvalidState("{} Must set {}", kErrorPrefix, + ADBC_INGEST_OPTION_TARGET_TABLE); } // Create statements for creating the table, inserting a row, and the table name @@ -844,8 +841,9 @@ class SqliteStatement : public driver::StatementBase { int status = ArrowSchemaViewInit(&view, binder_.schema.children[i], &arrow_error); if (status != 0) { - return status::Internal("failed to parse schema for column {}: {} ({}): {}", i, - std::strerror(status), status, arrow_error.message); + return status::fmt::Internal("failed to parse schema for column {}: {} ({}): {}", + i, std::strerror(status), status, + arrow_error.message); } switch (view.type) { @@ -924,27 +922,31 @@ class SqliteStatement : public driver::StatementBase { &stmt, /*pzTail=*/nullptr); if (rc != SQLITE_OK) { std::ignore = sqlite3_finalize(stmt); - return status::Internal("failed to prepare: {}\nquery was: {}", - sqlite3_errmsg(conn_), insert); + return status::fmt::Internal("failed to prepare: {}\nquery was: {}", + sqlite3_errmsg(conn_), insert); } } assert(stmt != nullptr); - AdbcStatusCode status = ADBC_STATUS_OK; + 
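[Editor's note] The load-extension connection options handled above map onto two SQLite calls; the sequence looks roughly like this (sketch; the path is a placeholder and a null entrypoint lets SQLite derive one from the file name):

// Sketch only: what the load-extension connection options translate to in SQLite.
#include <sqlite3.h>

int LoadExtensionExample(sqlite3* db, char** out_error_message) {
  // Extension loading is off by default; this enables only the C-API path
  // (the SQL load_extension() function stays disabled).
  int rc = sqlite3_db_config(db, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 1, nullptr);
  if (rc != SQLITE_OK) return rc;
  // Placeholder path; pass an explicit entrypoint instead of nullptr if needed.
  return sqlite3_load_extension(db, "/path/to/extension.so", nullptr,
                                out_error_message);
}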
AdbcStatusCode status_code = ADBC_STATUS_OK; + Status status = status::Ok(); struct AdbcError error = ADBC_ERROR_INIT; while (true) { char finished = 0; - status = AdbcSqliteBinderBindNext(&binder_, conn_, stmt, &finished, &error); - if (status != ADBC_STATUS_OK || finished) break; + status_code = AdbcSqliteBinderBindNext(&binder_, conn_, stmt, &finished, &error); + if (status_code != ADBC_STATUS_OK || finished) { + status = Status::FromAdbc(status_code, error); + break; + } int rc = 0; do { rc = sqlite3_step(stmt); } while (rc == SQLITE_ROW); if (rc != SQLITE_DONE) { - SetError(&error, "failed to execute: %s\nquery was: %s", sqlite3_errmsg(conn_), - insert.data()); - status = ADBC_STATUS_INTERNAL; + status = status::fmt::Internal("failed to execute: {}\nquery was: {}", + sqlite3_errmsg(conn_), insert.data()); + status_code = ADBC_STATUS_INTERNAL; break; } row_count++; @@ -952,15 +954,15 @@ class SqliteStatement : public driver::StatementBase { std::ignore = sqlite3_finalize(stmt); if (is_autocommit) { - if (status == ADBC_STATUS_OK) { + if (status_code == ADBC_STATUS_OK) { UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, "COMMIT")); } else { UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, "ROLLBACK")); } } - if (status != ADBC_STATUS_OK) { - return Status::FromAdbc(status, error); + if (status_code != ADBC_STATUS_OK) { + return status; } return row_count; } @@ -972,8 +974,8 @@ class SqliteStatement : public driver::StatementBase { const int64_t expected = sqlite3_bind_parameter_count(stmt_); const int64_t actual = binder_.schema.n_children; if (actual != expected) { - return status::InvalidState("parameter count mismatch: expected {} but found {}", - expected, actual); + return status::fmt::InvalidState( + "parameter count mismatch: expected {} but found {}", expected, actual); } auto status = @@ -1000,11 +1002,12 @@ class SqliteStatement : public driver::StatementBase { const int64_t expected = sqlite3_bind_parameter_count(stmt_); const int64_t actual = binder_.schema.n_children; if (actual != expected) { - return status::InvalidState("parameter count mismatch: expected {} but found {}", - expected, actual); + return status::fmt::InvalidState( + "parameter count mismatch: expected {} but found {}", expected, actual); } - int64_t rows = 0; + int64_t output_rows = 0; + int64_t changed_rows = 0; SqliteMutexGuard guard(conn_); @@ -1023,7 +1026,11 @@ class SqliteStatement : public driver::StatementBase { } while (sqlite3_step(stmt_) == SQLITE_ROW) { - rows++; + output_rows++; + } + + if (sqlite3_column_count(stmt_) == 0) { + changed_rows += sqlite3_changes(conn_); } if (!binder_.schema.release) break; @@ -1032,13 +1039,15 @@ class SqliteStatement : public driver::StatementBase { if (sqlite3_reset(stmt_) != SQLITE_OK) { const char* msg = sqlite3_errmsg(conn_); - return status::IO("failed to execute query: {}", msg ? msg : "(unknown error)"); + return status::fmt::IO("failed to execute query: {}", + msg ? 
msg : "(unknown error)"); } if (sqlite3_column_count(stmt_) == 0) { - rows = sqlite3_changes(conn_); + return changed_rows; + } else { + return output_rows; } - return rows; } Result ExecuteUpdateImpl(PreparedState& state) { return ExecuteUpdateImpl(); } @@ -1052,8 +1061,8 @@ class SqliteStatement : public driver::StatementBase { int num_params = sqlite3_bind_parameter_count(stmt_); if (num_params < 0) { // Should not happen - return status::Internal("{} SQLite returned negative parameter count", - kErrorPrefix); + return status::fmt::Internal("{} SQLite returned negative parameter count", + kErrorPrefix); } nanoarrow::UniqueSchema uschema; @@ -1078,7 +1087,7 @@ class SqliteStatement : public driver::StatementBase { Status InitImpl(void* parent) { conn_ = reinterpret_cast(parent)->conn(); - return StatementBase::InitImpl(parent); + return Statement::InitImpl(parent); } Status PrepareImpl(QueryState& state) { @@ -1086,8 +1095,8 @@ class SqliteStatement : public driver::StatementBase { int rc = sqlite3_finalize(stmt_); stmt_ = nullptr; if (rc != SQLITE_OK) { - return status::IO("{} Failed to finalize previous statement: ({}) {}", - kErrorPrefix, rc, sqlite3_errmsg(conn_)); + return status::fmt::IO("{} Failed to finalize previous statement: ({}) {}", + kErrorPrefix, rc, sqlite3_errmsg(conn_)); } } @@ -1098,8 +1107,8 @@ class SqliteStatement : public driver::StatementBase { std::string msg = sqlite3_errmsg(conn_); std::ignore = sqlite3_finalize(stmt_); stmt_ = NULL; - return status::InvalidArgument("{} Failed to prepare query: {}\nquery: {}", - kErrorPrefix, msg, state.query); + return status::fmt::InvalidArgument("{} Failed to prepare query: {}\nquery: {}", + kErrorPrefix, msg, state.query); } return status::Ok(); } @@ -1109,22 +1118,22 @@ class SqliteStatement : public driver::StatementBase { int rc = sqlite3_finalize(stmt_); stmt_ = nullptr; if (rc != SQLITE_OK) { - return status::IO("{} Failed to finalize statement: ({}) {}", kErrorPrefix, rc, - sqlite3_errmsg(conn_)); + return status::fmt::IO("{} Failed to finalize statement: ({}) {}", kErrorPrefix, + rc, sqlite3_errmsg(conn_)); } } AdbcSqliteBinderRelease(&binder_); - return StatementBase::ReleaseImpl(); + return Statement::ReleaseImpl(); } Status SetOptionImpl(std::string_view key, driver::Option value) { if (key == kStatementOptionBatchRows) { UNWRAP_RESULT(int64_t batch_size, value.AsInt()); if (batch_size >= std::numeric_limits::max() || batch_size <= 0) { - return status::InvalidArgument( + return status::fmt::InvalidArgument( "{} Invalid statement option value {}={} (value is non-positive or out of " "range of int)", - kErrorPrefix, key, value); + kErrorPrefix, key, value.Format()); } batch_size_ = static_cast(batch_size); return status::Ok(); diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index bfb432b153..8ceb747aca 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include #include @@ -79,10 +79,16 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_UINT32: case NANOARROW_TYPE_UINT64: return NANOARROW_TYPE_INT64; + case NANOARROW_TYPE_HALF_FLOAT: case NANOARROW_TYPE_FLOAT: - case NANOARROW_TYPE_DOUBLE: return NANOARROW_TYPE_DOUBLE; case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: + return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: + return NANOARROW_TYPE_BINARY; 
case NANOARROW_TYPE_DATE32: case NANOARROW_TYPE_TIMESTAMP: return NANOARROW_TYPE_STRING; @@ -328,6 +334,12 @@ class SqliteStatementTest : public ::testing::Test, void TestSqlIngestInterval() { GTEST_SKIP() << "Cannot ingest Interval (not implemented)"; } + void TestSqlIngestListOfInt32() { + GTEST_SKIP() << "Cannot ingest list (not implemented)"; + } + void TestSqlIngestListOfString() { + GTEST_SKIP() << "Cannot ingest list (not implemented)"; + } protected: void ValidateIngestedTemporalData(struct ArrowArrayView* values, ArrowType type, @@ -439,8 +451,11 @@ class SqliteReaderTest : public ::testing::Test { } void Bind(struct ArrowArray* batch, struct ArrowSchema* schema) { - ASSERT_THAT(AdbcSqliteBinderSetArray(&binder, batch, schema, &error), - IsOkStatus(&error)); + Handle stream; + struct ArrowArray batch_internal = *batch; + batch->release = nullptr; + adbc_validation::MakeStream(&stream.value, schema, {batch_internal}); + ASSERT_NO_FATAL_FAILURE(Bind(&stream.value)); } void Bind(struct ArrowArrayStream* stream) { diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c index 39edce69fa..69f90ebd68 100644 --- a/c/driver/sqlite/statement_reader.c +++ b/c/driver/sqlite/statement_reader.c @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +#if !defined(_WIN32) +#define _POSIX_C_SOURCE 200112L +#endif + #include "statement_reader.h" #include @@ -24,7 +28,7 @@ #include #include -#include +#include #include #include @@ -85,8 +89,11 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, switch (value_view.type) { case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: case NANOARROW_TYPE_BINARY: case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: break; default: SetError(error, "Column %d dictionary has unsupported type %s", i, @@ -101,14 +108,6 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, return ADBC_STATUS_OK; } -AdbcStatusCode AdbcSqliteBinderSetArray(struct AdbcSqliteBinder* binder, - struct ArrowArray* values, - struct ArrowSchema* schema, - struct AdbcError* error) { - AdbcSqliteBinderRelease(binder); - RAISE_ADBC(BatchToArrayStream(values, schema, &binder->params, error)); - return AdbcSqliteBinderSet(binder, error); -} // NOLINT(whitespace/indent) AdbcStatusCode AdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder, struct ArrowArrayStream* values, struct AdbcError* error) { @@ -330,11 +329,13 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 } else { switch (binder->types[col]) { case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_LARGE_BINARY: { + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: { struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], binder->next_row); - status = sqlite3_bind_blob(stmt, col + 1, value.data.as_char, value.size_bytes, - SQLITE_STATIC); + status = sqlite3_bind_blob(stmt, col + 1, value.data.as_char, + (int)value.size_bytes, SQLITE_STATIC); break; } case NANOARROW_TYPE_BOOL: @@ -363,6 +364,7 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 status = sqlite3_bind_int64(stmt, col + 1, value); break; } + case NANOARROW_TYPE_HALF_FLOAT: case NANOARROW_TYPE_FLOAT: case NANOARROW_TYPE_DOUBLE: { double value = ArrowArrayViewGetDoubleUnsafe(binder->batch.children[col], @@ 
-371,11 +373,12 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 break; } case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: { + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: { struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], binder->next_row); - status = sqlite3_bind_text(stmt, col + 1, value.data.as_char, value.size_bytes, - SQLITE_STATIC); + status = sqlite3_bind_text(stmt, col + 1, value.data.as_char, + (int)value.size_bytes, SQLITE_STATIC); break; } case NANOARROW_TYPE_DICTIONARY: { @@ -388,7 +391,7 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe( binder->batch.children[col]->dictionary, value_index); status = sqlite3_bind_text(stmt, col + 1, value.data.as_char, - value.size_bytes, SQLITE_STATIC); + (int)value.size_bytes, SQLITE_STATIC); } break; } @@ -408,16 +411,16 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 RAISE_ADBC(ArrowDate32ToIsoString((int32_t)value, &tsstr, error)); // SQLITE_TRANSIENT ensures the value is copied during bind - status = - sqlite3_bind_text(stmt, col + 1, tsstr, strlen(tsstr), SQLITE_TRANSIENT); + status = sqlite3_bind_text(stmt, col + 1, tsstr, (int)strlen(tsstr), + SQLITE_TRANSIENT); free(tsstr); break; } case NANOARROW_TYPE_TIMESTAMP: { struct ArrowSchemaView bind_schema_view; - RAISE_ADBC(ArrowSchemaViewInit(&bind_schema_view, binder->schema.children[col], - &arrow_error)); + RAISE_NA(ArrowSchemaViewInit(&bind_schema_view, binder->schema.children[col], + &arrow_error)); enum ArrowTimeUnit unit = bind_schema_view.time_unit; int64_t value = ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row); @@ -426,8 +429,8 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 RAISE_ADBC(ArrowTimestampToIsoString(value, unit, &tsstr, error)); // SQLITE_TRANSIENT ensures the value is copied during bind - status = - sqlite3_bind_text(stmt, col + 1, tsstr, strlen(tsstr), SQLITE_TRANSIENT); + status = sqlite3_bind_text(stmt, col + 1, tsstr, (int)strlen(tsstr), + SQLITE_TRANSIENT); free((char*)tsstr); break; } @@ -625,9 +628,18 @@ int StatementReaderGetNext(struct ArrowArrayStream* self, struct ArrowArray* out struct StatementReader* reader = (struct StatementReader*)self->private_data; if (reader->initial_batch.release != NULL) { - memcpy(out, &reader->initial_batch, sizeof(*out)); - memset(&reader->initial_batch, 0, sizeof(reader->initial_batch)); - return 0; + // Canonically return zero-row results as a stream with zero batches + if (reader->initial_batch.length == 0) { + reader->initial_batch.release(&reader->initial_batch); + reader->done = true; + + out->release = NULL; + return 0; + } else { + memcpy(out, &reader->initial_batch, sizeof(*out)); + memset(&reader->initial_batch, 0, sizeof(reader->initial_batch)); + return 0; + } } else if (reader->done) { out->release = NULL; return 0; @@ -771,7 +783,7 @@ AdbcStatusCode StatementReaderInitializeInfer(int num_columns, size_t infer_rows CHECK_NA(INTERNAL, ArrowBitmapReserve(&validity[i], infer_rows), error); ArrowBufferInit(&data[i]); CHECK_NA(INTERNAL, ArrowBufferReserve(&data[i], infer_rows * sizeof(int64_t)), error); - memset(&binary[i], 0, sizeof(struct ArrowBuffer)); + ArrowBufferInit(&binary[i]); current_type[i] = NANOARROW_TYPE_INT64; } return ADBC_STATUS_OK; @@ -832,7 +844,7 @@ AdbcStatusCode 
StatementReaderUpcastInt64ToDouble(struct ArrowBuffer* data, size_t num_elements = data->size_bytes / sizeof(int64_t); const int64_t* elements = (const int64_t*)data->data; for (size_t i = 0; i < num_elements; i++) { - double value = elements[i]; + double value = (double)elements[i]; ArrowBufferAppendUnsafe(&doubles, &value, sizeof(double)); } ArrowBufferReset(data); @@ -1121,7 +1133,7 @@ AdbcStatusCode AdbcSqliteExportReader(sqlite3* db, sqlite3_stmt* stmt, memset(reader, 0, sizeof(struct StatementReader)); reader->db = db; reader->stmt = stmt; - reader->batch_size = batch_size; + reader->batch_size = (int)batch_size; stream->private_data = reader; stream->release = StatementReaderRelease; diff --git a/c/driver/sqlite/statement_reader.h b/c/driver/sqlite/statement_reader.h index 63a222f05f..2e6b19086c 100644 --- a/c/driver/sqlite/statement_reader.h +++ b/c/driver/sqlite/statement_reader.h @@ -19,7 +19,7 @@ #pragma once -#include +#include #include #include @@ -40,11 +40,6 @@ struct ADBC_EXPORT AdbcSqliteBinder { int64_t next_row; }; -ADBC_EXPORT -AdbcStatusCode AdbcSqliteBinderSetArray(struct AdbcSqliteBinder* binder, - struct ArrowArray* values, - struct ArrowSchema* schema, - struct AdbcError* error); ADBC_EXPORT AdbcStatusCode AdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder, struct ArrowArrayStream* values, diff --git a/c/driver_manager/CMakeLists.txt b/c/driver_manager/CMakeLists.txt index 6fb51d9a6a..0eb17f0c8d 100644 --- a/c/driver_manager/CMakeLists.txt +++ b/c/driver_manager/CMakeLists.txt @@ -30,17 +30,19 @@ add_arrow_lib(adbc_driver_manager ${CMAKE_DL_LIBS} SHARED_LINK_FLAGS ${ADBC_LINK_FLAGS}) -include_directories(SYSTEM ${REPOSITORY_ROOT}) include_directories(SYSTEM ${REPOSITORY_ROOT}/c/) +include_directories(SYSTEM ${REPOSITORY_ROOT}/c/include/) include_directories(SYSTEM ${REPOSITORY_ROOT}/c/vendor) include_directories(SYSTEM ${REPOSITORY_ROOT}/c/driver) +install(FILES "${REPOSITORY_ROOT}/c/include/adbc.h" DESTINATION include) +install(FILES "${REPOSITORY_ROOT}/c/include/arrow-adbc/adbc.h" + DESTINATION include/arrow-adbc) + foreach(LIB_TARGET ${ADBC_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_EXPORTING) endforeach() -install(FILES "${REPOSITORY_ROOT}/adbc.h" adbc_driver_manager.h DESTINATION include) - if(ADBC_BUILD_TESTS) if(ADBC_TEST_LINKAGE STREQUAL "shared") set(TEST_LINK_LIBS adbc_driver_manager_shared) diff --git a/c/driver_manager/adbc-driver-manager.pc.in b/c/driver_manager/adbc-driver-manager.pc.in index 17b290e666..c20430566d 100644 --- a/c/driver_manager/adbc-driver-manager.pc.in +++ b/c/driver_manager/adbc-driver-manager.pc.in @@ -21,6 +21,7 @@ libdir=@ADBC_PKG_CONFIG_LIBDIR@ Name: Apache Arrow Database Connectivity (ADBC) driver manager Description: ADBC driver manager provides API to use ADBC driver. +URL: https://github.com/apache/arrow-adbc Version: @ADBC_VERSION@ Libs: -L${libdir} -ladbc_driver_manager Cflags: -I${includedir} diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index d8340c544b..0ce173a888 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. 
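[Editor's note] The statement_reader.c change above makes zero-row results come back as a stream that ends immediately, so consumers should key off the released ArrowArray rather than assume at least one batch. A minimal consumption sketch (the Arrow C data interface structs are the ones declared in arrow-adbc/adbc.h):

// Sketch only: draining an ArrowArrayStream; empty results simply yield no batches.
#include <cstdio>

#include <arrow-adbc/adbc.h>

int DrainStream(struct ArrowArrayStream* stream) {
  while (true) {
    struct ArrowArray batch;
    int rc = stream->get_next(stream, &batch);
    if (rc != 0) return rc;               // error reported by the producer
    if (batch.release == nullptr) break;  // end of stream (possibly zero batches)
    std::printf("batch with %lld rows\n", static_cast<long long>(batch.length));
    batch.release(&batch);
  }
  return 0;
}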
-#include "adbc_driver_manager.h" -#include +#include "arrow-adbc/adbc_driver_manager.h" +#include "arrow-adbc/adbc.h" #include #include @@ -84,6 +84,36 @@ void SetError(struct AdbcError* error, const std::string& message) { error->release = ReleaseError; } +// Copies src_error into error and releases src_error +void SetError(struct AdbcError* error, struct AdbcError* src_error) { + if (!error) return; + if (error->release) error->release(error); + + if (src_error->message) { + size_t message_size = strlen(src_error->message); + error->message = new char[message_size]; + std::memcpy(error->message, src_error->message, message_size); + error->message[message_size] = '\0'; + } else { + error->message = nullptr; + } + + error->release = ReleaseError; + if (src_error->release) { + src_error->release(src_error); + } +} + +struct OwnedError { + struct AdbcError error = ADBC_ERROR_INIT; + + ~OwnedError() { + if (error.release) { + error.release(&error); + } + } +}; + // Driver state /// A driver DLL. @@ -666,7 +696,7 @@ std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) { int AdbcErrorGetDetailCount(const struct AdbcError* error) { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && - error->private_driver) { + error->private_driver && error->private_driver->ErrorGetDetailCount) { return error->private_driver->ErrorGetDetailCount(error); } return 0; @@ -674,7 +704,7 @@ int AdbcErrorGetDetailCount(const struct AdbcError* error) { struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && - error->private_driver) { + error->private_driver && error->private_driver->ErrorGetDetail) { return error->private_driver->ErrorGetDetail(error, index); } return {nullptr, nullptr, 0}; @@ -900,6 +930,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, database->private_driver, error); } + if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease database->private_data = args; @@ -910,10 +941,18 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_driver = nullptr; return status; } - status = database->private_driver->DatabaseNew(database, error); + + // Errors that occur during AdbcDatabaseXXX() refer to the driver via + // the private_driver member; however, after we return we have released + // the driver and inspecting the error might segfault. Here, we scope + // the driver-produced error to this function and make a copy if necessary. 
+ OwnedError driver_error; + + status = database->private_driver->DatabaseNew(database, &driver_error.error); if (status != ADBC_STATUS_OK) { if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + SetError(error, &driver_error.error); + database->private_driver->release(database->private_driver, nullptr); } delete database->private_driver; database->private_driver = nullptr; @@ -927,33 +966,34 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* INIT_ERROR(error, database); for (const auto& option : options) { - status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), - option.second.c_str(), error); + status = database->private_driver->DatabaseSetOption( + database, option.first.c_str(), option.second.c_str(), &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : bytes_options) { status = database->private_driver->DatabaseSetOptionBytes( database, option.first.c_str(), reinterpret_cast(option.second.data()), option.second.size(), - error); + &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : int_options) { status = database->private_driver->DatabaseSetOptionInt( - database, option.first.c_str(), option.second, error); + database, option.first.c_str(), option.second, &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : double_options) { status = database->private_driver->DatabaseSetOptionDouble( - database, option.first.c_str(), option.second, error); + database, option.first.c_str(), option.second, &driver_error.error); if (status != ADBC_STATUS_OK) break; } if (status != ADBC_STATUS_OK) { // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); + std::ignore = database->private_driver->DatabaseRelease(database, nullptr); if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + SetError(error, &driver_error.error); + database->private_driver->release(database->private_driver, nullptr); } delete database->private_driver; database->private_driver = nullptr; @@ -962,6 +1002,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_data = nullptr; return status; } + return database->private_driver->DatabaseInit(database, error); } diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc index 0eaf644b93..c2342ebae2 100644 --- a/c/driver_manager/adbc_driver_manager_test.cc +++ b/c/driver_manager/adbc_driver_manager_test.cc @@ -22,8 +22,8 @@ #include #include -#include "adbc.h" -#include "adbc_driver_manager.h" +#include "arrow-adbc/adbc.h" +#include "arrow-adbc/adbc_driver_manager.h" #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" @@ -187,10 +187,18 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_UINT32: case NANOARROW_TYPE_UINT64: return NANOARROW_TYPE_INT64; + case NANOARROW_TYPE_HALF_FLOAT: case NANOARROW_TYPE_FLOAT: - case NANOARROW_TYPE_DOUBLE: return NANOARROW_TYPE_DOUBLE; case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: + return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: + return NANOARROW_TYPE_BINARY; + case NANOARROW_TYPE_DATE32: + case NANOARROW_TYPE_TIMESTAMP: return NANOARROW_TYPE_STRING; 
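[Editor's note] The added null checks above guard the 1.1.0 error-detail accessors against drivers that do not implement them; from an application's point of view the accessors are used like this (sketch):

// Sketch only: enumerating structured error details after a failed ADBC call.
#include <cstdio>

#include <arrow-adbc/adbc.h>

void DumpErrorDetails(const struct AdbcError* error) {
  int count = AdbcErrorGetDetailCount(error);  // 0 when the driver has no details
  for (int i = 0; i < count; i++) {
    struct AdbcErrorDetail detail = AdbcErrorGetDetail(error, i);
    std::printf("%s: %zu bytes\n", detail.key, detail.value_length);
  }
}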
default: return ingest_type; @@ -267,8 +275,6 @@ class SqliteStatementTest : public ::testing::Test, void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); } void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of range)"; } - void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; } - void TestSqlIngestDate32() { GTEST_SKIP() << "Cannot ingest DATE (not implemented)"; } void TestSqlIngestTimestamp() { GTEST_SKIP() << "Cannot ingest TIMESTAMP (not implemented)"; } @@ -281,6 +287,12 @@ class SqliteStatementTest : public ::testing::Test, void TestSqlIngestInterval() { GTEST_SKIP() << "Cannot ingest Interval (not implemented)"; } + void TestSqlIngestListOfInt32() { + GTEST_SKIP() << "Cannot ingest list (not implemented)"; + } + void TestSqlIngestListOfString() { + GTEST_SKIP() << "Cannot ingest list (not implemented)"; + } protected: SqliteQuirks quirks_; diff --git a/c/driver_manager/adbc_version_100.h b/c/driver_manager/adbc_version_100.h index b349f86f73..95a045db1c 100644 --- a/c/driver_manager/adbc_version_100.h +++ b/c/driver_manager/adbc_version_100.h @@ -17,7 +17,7 @@ // A dummy version 1.0.0 ADBC driver to test compatibility. -#include +#include #ifdef __cplusplus extern "C" { diff --git a/c/driver_manager/adbc_version_100_compatibility_test.cc b/c/driver_manager/adbc_version_100_compatibility_test.cc index 27e5f5d997..43079ecb3e 100644 --- a/c/driver_manager/adbc_version_100_compatibility_test.cc +++ b/c/driver_manager/adbc_version_100_compatibility_test.cc @@ -20,9 +20,9 @@ #include -#include "adbc.h" -#include "adbc_driver_manager.h" #include "adbc_version_100.h" +#include "arrow-adbc/adbc.h" +#include "arrow-adbc/adbc_driver_manager.h" #include "validation/adbc_validation_util.h" namespace adbc { diff --git a/c/driver_manager/meson.build b/c/driver_manager/meson.build new file mode 100644 index 0000000000..6be37f96f4 --- /dev/null +++ b/c/driver_manager/meson.build @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +adbc_driver_manager_lib = library( + 'adbc_driver_manager', + 'adbc_driver_manager.cc', + include_directories: [include_dir], + install: true, +) + +pkg.generate( + name: 'Apache Arrow Database Connectivity (ADBC) driver manager', + description: 'ADBC driver manager provides API to use ADBC driver.', + url: 'https://github.com/apache/arrow-adbc', + libraries: [adbc_driver_manager_lib], + filebase: 'adbc-driver-manager', +) diff --git a/c/include/adbc.h b/c/include/adbc.h new file mode 100644 index 0000000000..e21d1dc836 --- /dev/null +++ b/c/include/adbc.h @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma message( \ + "Including \"adbc.h\" is deprecated. " \ + "Please include \"arrow-adbc/adbc.h\" instead") + +#include "arrow-adbc/adbc.h" diff --git a/c/include/adbc_driver_manager.h b/c/include/adbc_driver_manager.h new file mode 100644 index 0000000000..88f8d4d584 --- /dev/null +++ b/c/include/adbc_driver_manager.h @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma message( \ + "Including \"adbc_driver_manager.h\" is deprecated. " \ + "Please include \"arrow-adbc/adbc_driver_manager.h\" instead") + +#include "arrow-adbc/adbc_driver_manager.h" diff --git a/adbc.h b/c/include/arrow-adbc/adbc.h similarity index 99% rename from adbc.h rename to c/include/arrow-adbc/adbc.h index fe0fc4f70c..b965672e6f 100644 --- a/adbc.h +++ b/c/include/arrow-adbc/adbc.h @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -/// \file adbc.h ADBC: Arrow Database connectivity +/// \file arrow-adbc/adbc.h ADBC: Arrow Database connectivity /// /// An Arrow-based interface between applications and database /// drivers. ADBC aims to provide a vendor-independent API for SQL @@ -1972,7 +1972,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, /// \since ADBC API revision 1.1.0 /// /// \param[in] statement The statement to execute. -/// \param[out] out The result schema. +/// \param[out] schema The result schema. /// \param[out] error An optional location to return an error /// message if necessary. /// diff --git a/go/adbc/drivermgr/adbc_driver_manager.h b/c/include/arrow-adbc/adbc_driver_manager.h similarity index 94% rename from go/adbc/drivermgr/adbc_driver_manager.h rename to c/include/arrow-adbc/adbc_driver_manager.h index 69b767a583..c32368ab69 100644 --- a/go/adbc/drivermgr/adbc_driver_manager.h +++ b/c/include/arrow-adbc/adbc_driver_manager.h @@ -15,9 +15,14 @@ // specific language governing permissions and limitations // under the License. 
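[Editor's note] The adbc.h doc fix above renames the output parameter in the documentation for the 1.1.0 schema-only execution entry point so it matches the declared parameter name. For reference, a call looks like this (sketch, error handling trimmed):

// Sketch only: fetching the result schema without executing the query.
#include <arrow-adbc/adbc.h>

AdbcStatusCode DescribeQuery(struct AdbcStatement* statement, struct AdbcError* error) {
  struct ArrowSchema schema;
  AdbcStatusCode status = AdbcStatementExecuteSchema(statement, &schema, error);
  if (status != ADBC_STATUS_OK) return status;
  // ... inspect schema.n_children, schema.children[i]->name, etc. ...
  schema.release(&schema);
  return ADBC_STATUS_OK;
}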
+/// \file arrow-adbc/adbc_driver_manager.h ADBC Driver Manager +/// +/// A helper library to dynamically load and use multiple ADBC drivers in the +/// same process. + #pragma once -#include +#include #ifdef __cplusplus extern "C" { diff --git a/c/integration/duckdb/CMakeLists.txt b/c/integration/duckdb/CMakeLists.txt index 589d9842cd..9065450b0d 100644 --- a/c/integration/duckdb/CMakeLists.txt +++ b/c/integration/duckdb/CMakeLists.txt @@ -68,7 +68,7 @@ if(ADBC_BUILD_TESTS) add_dependencies(adbc-integration-duckdb-test duckdb) target_compile_features(adbc-integration-duckdb-test PRIVATE cxx_std_17) target_include_directories(adbc-integration-duckdb-test SYSTEM - PRIVATE ${REPOSITORY_ROOT} ${REPOSITORY_ROOT}/c/ + PRIVATE ${REPOSITORY_ROOT}/c/ ${REPOSITORY_ROOT}/c/include/ ${REPOSITORY_ROOT}/c/vendor ${REPOSITORY_ROOT}/c/driver) adbc_configure_target(adbc-integration-duckdb-test) diff --git a/c/integration/duckdb/duckdb_test.cc b/c/integration/duckdb/duckdb_test.cc index 37bb03e563..5a8ecaf7b4 100644 --- a/c/integration/duckdb/duckdb_test.cc +++ b/c/integration/duckdb/duckdb_test.cc @@ -17,8 +17,8 @@ #include -#include -#include +#include +#include #include #include "validation/adbc_validation.h" diff --git a/c/meson.build b/c/meson.build new file mode 100644 index 0000000000..b5b9fbce70 --- /dev/null +++ b/c/meson.build @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +project( + 'arrow-adbc', + 'c', 'cpp', + version: '1.4.0-SNAPSHOT', + license: 'Apache-2.0', + meson_version: '>=1.3.0', + default_options: [ + 'buildtype=release', + 'c_std=c99', + 'warning_level=2', + 'cpp_std=c++17', + ] +) + +add_project_arguments('-Wno-int-conversion', '-Wno-unused-parameter', language: 'c') +add_project_arguments('-Wno-unused-parameter', '-Wno-reorder', language: 'cpp') + +c_dir = include_directories('.') +include_dir = include_directories('include') +install_headers('include/adbc.h') +install_headers('include/arrow-adbc/adbc.h', subdir: 'arrow-adbc') +driver_dir = include_directories('driver') +nanoarrow_dep = dependency('nanoarrow') +fmt_dep = dependency('fmt') + +if get_option('tests') + gtest_main_dep = dependency('gtest_main') + gmock_dep = dependency('gmock') +else + gtest_main_dep = disabler() + gmock_dep = disabler() +endif + +needs_driver_manager = get_option('driver_manager') \ + or get_option('tests') + +pkg = import('pkgconfig') + +if needs_driver_manager + install_headers('include/adbc_driver_manager.h') + install_headers('include/arrow-adbc/adbc_driver_manager.h', subdir: 'arrow-adbc') + subdir('driver_manager') +endif + +subdir('driver/common') +subdir('driver/framework') + +if get_option('tests') + subdir('validation') +endif + +if get_option('bigquery') + subdir('driver/bigquery') +endif + +if get_option('flightsql') + subdir('driver/flightsql') +endif + +if get_option('postgresql') + subdir('driver/postgresql') +endif + +if get_option('sqlite') + subdir('driver/sqlite') +endif + +if get_option('snowflake') + subdir('driver/snowflake') +endif diff --git a/c/meson.options b/c/meson.options new file mode 100644 index 0000000000..87d5534495 --- /dev/null +++ b/c/meson.options @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +option('tests', type: 'boolean', description: 'Build tests', value: false) +option('benchmarks', type: 'boolean', description: 'Build benchmarks', value: false) +option( + 'bigquery', + type: 'boolean', + description: 'Build ADBC BigQuery driver', + value: false +) +option( + 'flightsql', + type: 'boolean', + description: 'Build ADBC FlightSQL driver', + value: false +) +option( + 'driver_manager', + type: 'boolean', + description: 'Build ADBC Driver Manager', + value: false +) +option( + 'postgresql', + type: 'boolean', + description: 'Build ADBC PostgreSQL Driver', + value: false +) +option( + 'sqlite', + type: 'boolean', + description: 'Build ADBC SQLite Driver', + value: false +) +option( + 'snowflake', + type: 'boolean', + description: 'Build ADBC Snowflake Driver', + value: false +) diff --git a/c/subprojects/fmt.wrap b/c/subprojects/fmt.wrap new file mode 100644 index 0000000000..4e96460106 --- /dev/null +++ b/c/subprojects/fmt.wrap @@ -0,0 +1,13 @@ +[wrap-file] +directory = fmt-10.2.0 +source_url = https://github.com/fmtlib/fmt/archive/10.2.0.tar.gz +source_filename = fmt-10.2.0.tar.gz +source_hash = 3ca91733a7313a8ad41c0885929415f8ec0a2a31d4dc7e27e9331412f4ca26ac +patch_filename = fmt_10.2.0-2_patch.zip +patch_url = https://wrapdb.mesonbuild.com/v2/fmt_10.2.0-2/get_patch +patch_hash = 2428c3a386a8390c76378f81ef804a297f4edc3b789499dd56629b7902b8ddb7 +source_fallback_url = https://github.com/mesonbuild/wrapdb/releases/download/fmt_10.2.0-2/fmt-10.2.0.tar.gz +wrapdb_version = 10.2.0-2 + +[provide] +fmt = fmt_dep diff --git a/c/subprojects/gtest.wrap b/c/subprojects/gtest.wrap new file mode 100644 index 0000000000..adb8a9a6d9 --- /dev/null +++ b/c/subprojects/gtest.wrap @@ -0,0 +1,16 @@ +[wrap-file] +directory = googletest-1.14.0 +source_url = https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz +source_filename = gtest-1.14.0.tar.gz +source_hash = 8ad598c73ad796e0d8280b082cebd82a630d73e73cd3c70057938a6501bba5d7 +patch_filename = gtest_1.14.0-2_patch.zip +patch_url = https://wrapdb.mesonbuild.com/v2/gtest_1.14.0-2/get_patch +patch_hash = 4ec7f767364386a99f7b2d61678287a73ad6ba0f9998be43b51794c464a63732 +source_fallback_url = https://github.com/mesonbuild/wrapdb/releases/download/gtest_1.14.0-2/gtest-1.14.0.tar.gz +wrapdb_version = 1.14.0-2 + +[provide] +gtest = gtest_dep +gtest_main = gtest_main_dep +gmock = gmock_dep +gmock_main = gmock_main_dep diff --git a/c/subprojects/nanoarrow.wrap b/c/subprojects/nanoarrow.wrap new file mode 100644 index 0000000000..612d1118a6 --- /dev/null +++ b/c/subprojects/nanoarrow.wrap @@ -0,0 +1,8 @@ +[wrap-file] +directory = arrow-nanoarrow-33d2c8b973d8f8f424e02ac92ddeaace2a92f8dd +source_url = https://github.com/apache/arrow-nanoarrow/archive/33d2c8b973d8f8f424e02ac92ddeaace2a92f8dd.tar.gz +source_filename = arrow-nanoarrow-33d2c8b973d8f8f424e02ac92ddeaace2a92f8dd.tar.gz +source_hash = be4d2a6f1467793fe1b02c6ecf12383ed9ecf29557531715a3b9e11578ab18e8 + +[provide] +nanoarrow = nanoarrow_dep diff --git a/c/subprojects/sqlite3.wrap b/c/subprojects/sqlite3.wrap new file mode 100644 index 0000000000..5d2eb5c512 --- /dev/null +++ b/c/subprojects/sqlite3.wrap @@ -0,0 +1,13 @@ +[wrap-file] +directory = sqlite-amalgamation-3460000 +source_url = https://www.sqlite.org/2024/sqlite-amalgamation-3460000.zip +source_filename = sqlite-amalgamation-3460000.zip +source_hash = 712a7d09d2a22652fb06a49af516e051979a3984adb067da86760e60ed51a7f5 +patch_filename = sqlite3_3.46.0-1_patch.zip +patch_url = 
https://wrapdb.mesonbuild.com/v2/sqlite3_3.46.0-1/get_patch +patch_hash = c6dda193e59e4bd11dbc6f399cae1904d231e0cb119224480bec6c94c5d0e04e +source_fallback_url = https://github.com/mesonbuild/wrapdb/releases/download/sqlite3_3.46.0-1/sqlite-amalgamation-3460000.zip +wrapdb_version = 3.46.0-1 + +[provide] +sqlite3 = sqlite3_dep diff --git a/c/validation/CMakeLists.txt b/c/validation/CMakeLists.txt index eb29e14691..04bc0115aa 100644 --- a/c/validation/CMakeLists.txt +++ b/c/validation/CMakeLists.txt @@ -19,7 +19,8 @@ add_library(adbc_validation_util STATIC adbc_validation_util.cc) adbc_configure_target(adbc_validation_util) target_compile_features(adbc_validation_util PRIVATE cxx_std_17) target_include_directories(adbc_validation_util SYSTEM - PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/driver/" + PRIVATE "${REPOSITORY_ROOT}/c/include/" + "${REPOSITORY_ROOT}/c/driver/" "${REPOSITORY_ROOT}/c/vendor/") target_link_libraries(adbc_validation_util PUBLIC adbc_driver_common nanoarrow GTest::gtest GTest::gmock) @@ -30,7 +31,8 @@ add_library(adbc_validation OBJECT adbc_configure_target(adbc_validation) target_compile_features(adbc_validation PRIVATE cxx_std_17) target_include_directories(adbc_validation SYSTEM - PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/driver/" + PRIVATE "${REPOSITORY_ROOT}/c/include/" + "${REPOSITORY_ROOT}/c/driver/" "${REPOSITORY_ROOT}/c/vendor/") target_link_libraries(adbc_validation PUBLIC adbc_driver_common diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc index 355bfd5086..5cd592679e 100644 --- a/c/validation/adbc_validation.cc +++ b/c/validation/adbc_validation.cc @@ -29,7 +29,7 @@ #include #include -#include +#include #include #include #include diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index abe9a76868..427e39b2e2 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -23,10 +23,12 @@ #include #include -#include +#include #include #include +#include "adbc_validation_util.h" + namespace adbc_validation { #define ADBCV_STRINGIFY(s) #s @@ -160,6 +162,18 @@ class DriverQuirks { return ingest_type; } + /// \brief For a given Arrow type of (possibly nested) ingested data, what Arrow type + /// will the database return when that column is selected? + virtual SchemaField IngestSelectRoundTripType(SchemaField ingest_field) const { + SchemaField out(ingest_field.name, IngestSelectRoundTripType(ingest_field.type), + ingest_field.nullable); + for (const auto& child : ingest_field.children) { + out.children.push_back(IngestSelectRoundTripType(child)); + } + + return out; + } + /// \brief Whether bulk ingest is supported virtual bool supports_bulk_ingest(const char* mode) const { return true; } @@ -224,6 +238,12 @@ class DriverQuirks { /// column matching. 
virtual bool supports_error_on_incompatible_schema() const { return true; } + /// \brief Whether ingestion supports StringView/BinaryView types + virtual bool supports_ingest_view_types() const { return true; } + + /// \brief Whether ingestion supports Float16 + virtual bool supports_ingest_float16() const { return true; } + /// \brief Default catalog to use for tests virtual std::string catalog() const { return ""; } @@ -344,7 +364,7 @@ class StatementTest { void TestNewInit(); void TestRelease(); - // ---- Type-specific tests -------------------- + // ---- Type-specific ingest tests ------------- void TestSqlIngestBool(); @@ -359,13 +379,18 @@ class StatementTest { void TestSqlIngestUInt64(); // Floats + void TestSqlIngestFloat16(); void TestSqlIngestFloat32(); void TestSqlIngestFloat64(); // Strings void TestSqlIngestString(); void TestSqlIngestLargeString(); + void TestSqlIngestStringView(); void TestSqlIngestBinary(); + void TestSqlIngestLargeBinary(); + void TestSqlIngestFixedSizeBinary(); + void TestSqlIngestBinaryView(); // Temporal void TestSqlIngestDuration(); @@ -377,6 +402,12 @@ class StatementTest { // Dictionary-encoded void TestSqlIngestStringDictionary(); + // Nested + void TestSqlIngestListOfInt32(); + void TestSqlIngestListOfString(); + + void TestSqlIngestStreamZeroArrays(); + // ---- End Type-specific tests ---------------- void TestSqlIngestTableEscaping(); @@ -407,6 +438,8 @@ class StatementTest { void TestSqlPrepareErrorNoQuery(); void TestSqlPrepareErrorParamCountMismatch(); + void TestSqlBind(); + void TestSqlQueryEmpty(); void TestSqlQueryInts(); void TestSqlQueryFloats(); @@ -438,6 +471,11 @@ class StatementTest { struct AdbcConnection connection; struct AdbcStatement statement; + template + void TestSqlIngestType(SchemaField type, + const std::vector>& values, + bool dictionary_encode); + template void TestSqlIngestType(ArrowType type, const std::vector>& values, bool dictionary_encode); @@ -453,6 +491,14 @@ class StatementTest { const char* timezone); }; +template +void StatementTest::TestSqlIngestType(ArrowType type, + const std::vector>& values, + bool dictionary_encode) { + SchemaField field("col", type); + TestSqlIngestType(field, values, dictionary_encode); +} + #define ADBCV_TEST_STATEMENT(FIXTURE) \ static_assert(std::is_base_of::value, \ ADBCV_STRINGIFY(FIXTURE) " must inherit from StatementTest"); \ @@ -467,17 +513,25 @@ class StatementTest { TEST_F(FIXTURE, SqlIngestUInt16) { TestSqlIngestUInt16(); } \ TEST_F(FIXTURE, SqlIngestUInt32) { TestSqlIngestUInt32(); } \ TEST_F(FIXTURE, SqlIngestUInt64) { TestSqlIngestUInt64(); } \ + TEST_F(FIXTURE, SqlIngestFloat16) { TestSqlIngestFloat16(); } \ TEST_F(FIXTURE, SqlIngestFloat32) { TestSqlIngestFloat32(); } \ TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); } \ TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); } \ TEST_F(FIXTURE, SqlIngestLargeString) { TestSqlIngestLargeString(); } \ + TEST_F(FIXTURE, SqlIngestStringView) { TestSqlIngestStringView(); } \ TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \ + TEST_F(FIXTURE, SqlIngestLargeBinary) { TestSqlIngestLargeBinary(); } \ + TEST_F(FIXTURE, SqlIngestFixedSizeBinary) { TestSqlIngestFixedSizeBinary(); } \ + TEST_F(FIXTURE, SqlIngestBinaryView) { TestSqlIngestBinaryView(); } \ TEST_F(FIXTURE, SqlIngestDuration) { TestSqlIngestDuration(); } \ TEST_F(FIXTURE, SqlIngestDate32) { TestSqlIngestDate32(); } \ TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \ TEST_F(FIXTURE, SqlIngestTimestampTz) { 
TestSqlIngestTimestampTz(); } \ TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \ TEST_F(FIXTURE, SqlIngestStringDictionary) { TestSqlIngestStringDictionary(); } \ + TEST_F(FIXTURE, SqlIngestListOfInt32) { TestSqlIngestListOfInt32(); } \ + TEST_F(FIXTURE, SqlIngestListOfString) { TestSqlIngestListOfString(); } \ + TEST_F(FIXTURE, TestSqlIngestStreamZeroArrays) { TestSqlIngestStreamZeroArrays(); } \ TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \ TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); } \ TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \ @@ -505,6 +559,7 @@ class StatementTest { TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) { \ TestSqlPrepareErrorParamCountMismatch(); \ } \ + TEST_F(FIXTURE, SqlBind) { TestSqlBind(); } \ TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); } \ TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \ TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \ diff --git a/c/validation/adbc_validation_connection.cc b/c/validation/adbc_validation_connection.cc index 4532c18556..032f1d328f 100644 --- a/c/validation/adbc_validation_connection.cc +++ b/c/validation/adbc_validation_connection.cc @@ -19,7 +19,7 @@ #include -#include +#include #include #include @@ -425,10 +425,10 @@ void CheckGetObjectsSchema(struct ArrowSchema* schema) { {"constraint_column_names", NANOARROW_TYPE_LIST, NOT_NULL}, {"constraint_column_usage", NANOARROW_TYPE_LIST, NULLABLE}, })); - ASSERT_NO_FATAL_FAILURE(CompareSchema( - constraint_schema->children[2], { - {std::nullopt, NANOARROW_TYPE_STRING, NULLABLE}, - })); + ASSERT_NO_FATAL_FAILURE(CompareSchema(constraint_schema->children[2], + { + {"", NANOARROW_TYPE_STRING, NULLABLE}, + })); struct ArrowSchema* usage_schema = constraint_schema->children[3]->children[0]; ASSERT_NO_FATAL_FAILURE( @@ -744,13 +744,15 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { struct TestCase { std::optional filter; - std::vector column_names; - std::vector ordinal_positions; + // the pair is column name & ordinal position of the column + std::vector> columns; }; std::vector test_cases; - test_cases.push_back({std::nullopt, {"int64s", "strings"}, {1, 2}}); - test_cases.push_back({"in%", {"int64s"}, {1}}); + test_cases.push_back({std::nullopt, {{"int64s", 1}, {"strings", 2}}}); + test_cases.push_back({"in%", {{"int64s", 1}}}); + + const std::string catalog = quirks()->catalog(); for (const auto& test_case : test_cases) { std::string scope = "Filter: "; @@ -758,13 +760,14 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { SCOPED_TRACE(scope); StreamReader reader; + std::vector> columns; std::vector column_names; std::vector ordinal_positions; ASSERT_THAT( AdbcConnectionGetObjects( - &connection, ADBC_OBJECT_DEPTH_COLUMNS, nullptr, nullptr, nullptr, nullptr, - test_case.filter.has_value() ? test_case.filter->c_str() : nullptr, + &connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr, + nullptr, test_case.filter.has_value() ? 
test_case.filter->c_str() : nullptr, &reader.stream.value, &error), IsOkStatus(&error)); ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); @@ -834,10 +837,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { std::string temp(name.data, name.size_bytes); std::transform(temp.begin(), temp.end(), temp.begin(), [](unsigned char c) { return std::tolower(c); }); - column_names.push_back(std::move(temp)); - ordinal_positions.push_back( - static_cast(ArrowArrayViewGetIntUnsafe( - table_columns->children[1], columns_index))); + columns.emplace_back(std::move(temp), + static_cast(ArrowArrayViewGetIntUnsafe( + table_columns->children[1], columns_index))); } } } @@ -847,8 +849,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { } while (reader.array->release); ASSERT_TRUE(found_expected_table) << "Did (not) find table in metadata"; - ASSERT_EQ(test_case.column_names, column_names); - ASSERT_EQ(test_case.ordinal_positions, ordinal_positions); + // metadata columns do not guarantee the order they are returned in, just + // validate all the elements are there. + ASSERT_THAT(columns, testing::UnorderedElementsAreArray(test_case.columns)); } } diff --git a/c/validation/adbc_validation_database.cc b/c/validation/adbc_validation_database.cc index 371226cc42..78a3a19999 100644 --- a/c/validation/adbc_validation_database.cc +++ b/c/validation/adbc_validation_database.cc @@ -19,7 +19,7 @@ #include -#include +#include #include #include diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc index 333baf1414..cd388623ba 100644 --- a/c/validation/adbc_validation_statement.cc +++ b/c/validation/adbc_validation_statement.cc @@ -19,7 +19,7 @@ #include -#include +#include #include #include #include @@ -79,9 +79,12 @@ void StatementTest::TestRelease() { } template -void StatementTest::TestSqlIngestType(ArrowType type, +void StatementTest::TestSqlIngestType(SchemaField field, const std::vector>& values, bool dictionary_encode) { + // Override the field name + field.name = "col"; + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -92,7 +95,7 @@ void StatementTest::TestSqlIngestType(ArrowType type, Handle schema; Handle array; struct ArrowError na_error; - ASSERT_THAT(MakeSchema(&schema.value, {{"col", type}}), IsOkErrno()); + ASSERT_THAT(MakeSchema(&schema.value, {field}), IsOkErrno()); ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, values), IsOkErrno()); @@ -155,16 +158,15 @@ void StatementTest::TestSqlIngestType(ArrowType type, ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1))); ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}})); + SchemaField round_trip_field = quirks()->IngestSelectRoundTripType(field); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, {round_trip_field})); ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_NE(nullptr, reader.array->release); ASSERT_EQ(values.size(), reader.array->length); ASSERT_EQ(1, reader.array->n_children); - if (round_trip_type == type) { + if (round_trip_field.type == field.type) { // XXX: for now we can't compare values; we would need casting ASSERT_NO_FATAL_FAILURE( CompareArray(reader.array_view->children[0], values)); @@ -235,6 +237,14 @@ void StatementTest::TestSqlIngestInt64() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_INT64)); } +void 
StatementTest::TestSqlIngestFloat16() { + if (!quirks()->supports_ingest_float16()) { + GTEST_SKIP(); + } + + ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_HALF_FLOAT)); +} + void StatementTest::TestSqlIngestFloat32() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_FLOAT)); } @@ -253,6 +263,16 @@ void StatementTest::TestSqlIngestLargeString() { NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"}, false)); } +void StatementTest::TestSqlIngestStringView() { + if (!quirks()->supports_ingest_view_types()) { + GTEST_SKIP(); + } + + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( + NANOARROW_TYPE_STRING_VIEW, {std::nullopt, "", "", "longer than 12 bytes", "例"}, + false)); +} + void StatementTest::TestSqlIngestBinary() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( NANOARROW_TYPE_BINARY, @@ -264,6 +284,38 @@ void StatementTest::TestSqlIngestBinary() { false)); } +void StatementTest::TestSqlIngestLargeBinary() { + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( + NANOARROW_TYPE_LARGE_BINARY, + {std::nullopt, std::vector{}, + std::vector{std::byte{0x00}, std::byte{0x01}}, + std::vector{std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, + std::byte{0x04}}, + std::vector{std::byte{0xfe}, std::byte{0xff}}}, + false)); +} + +void StatementTest::TestSqlIngestFixedSizeBinary() { + SchemaField field = SchemaField::FixedSize("col", NANOARROW_TYPE_FIXED_SIZE_BINARY, 4); + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( + field, {std::nullopt, "abcd", "efgh", "ijkl", "mnop"}, false)); +} + +void StatementTest::TestSqlIngestBinaryView() { + if (!quirks()->supports_ingest_view_types()) { + GTEST_SKIP(); + } + + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( + NANOARROW_TYPE_LARGE_BINARY, + {std::nullopt, std::vector{}, + std::vector{std::byte{0x00}, std::byte{0x01}}, + std::vector{std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, + std::byte{0x04}}, + std::vector{std::byte{0xfe}, std::byte{0xff}}}, + false)); +} + void StatementTest::TestSqlIngestDate32() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_DATE32)); } @@ -491,6 +543,70 @@ void StatementTest::TestSqlIngestStringDictionary() { /*dictionary_encode*/ true)); } +void StatementTest::TestSqlIngestListOfInt32() { + SchemaField field = + SchemaField::Nested("col", NANOARROW_TYPE_LIST, {{"item", NANOARROW_TYPE_INT32}}); + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( + field, {std::nullopt, std::vector{1, 2, 3}, std::vector{4, 5}}, + /*dictionary_encode*/ false)); +} + +void StatementTest::TestSqlIngestListOfString() { + SchemaField field = + SchemaField::Nested("col", NANOARROW_TYPE_LIST, {{"item", NANOARROW_TYPE_STRING}}); + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( + field, + {std::nullopt, std::vector{"abc", "defg"}, + std::vector{"hijk"}}, + /*dictionary_encode*/ false)); +} + +void StatementTest::TestSqlIngestStreamZeroArrays() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + + ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), + IsOkStatus(&error)); + + Handle schema; + ASSERT_THAT(MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_INT32}}), IsOkErrno()); + + Handle bind; + nanoarrow::EmptyArrayStream(&schema.value).ToArrayStream(&bind.value); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBindStream(&statement, &bind.value, &error), + 
IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT * FROM \"bulk_ingest\"", &error), + IsOkStatus(&error)); + + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(NANOARROW_TYPE_INT32); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } +} + void StatementTest::TestSqlIngestTableEscaping() { std::string name = "create_table_escaping"; @@ -2062,18 +2178,83 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() { ::testing::Not(IsOkStatus(&error))); } +void StatementTest::TestSqlBind() { + if (!quirks()->supports_dynamic_parameter_binding()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + ASSERT_THAT(quirks()->DropTable(&connection, "bindtest", &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, "CREATE TABLE bindtest (col1 INTEGER, col2 TEXT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, + {{"", NANOARROW_TYPE_INT32}, {"", NANOARROW_TYPE_STRING}}), + IsOkErrno()); + + std::vector> int_values{std::nullopt, -123, 123}; + std::vector> string_values{"abc", std::nullopt, "defg"}; + + int batch_result = MakeBatch( + &schema.value, &array.value, &na_error, int_values, string_values); + ASSERT_THAT(batch_result, IsOkErrno()); + + auto insert_query = std::string("INSERT INTO bindtest VALUES (") + + quirks()->BindParameter(0) + ", " + quirks()->BindParameter(1) + + ")"; + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, insert_query.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + int64_t rows_affected = -10; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(-1), ::testing::Eq(3))); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "SELECT * FROM bindtest ORDER BY col1 ASC NULLS FIRST", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 3); + CompareArray(reader.array_view->children[0], int_values); + CompareArray(reader.array_view->children[1], string_values); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } +} + void StatementTest::TestSqlQueryEmpty() { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), 
IsOkStatus(&error)); - ASSERT_THAT(quirks()->DropTable(&connection, "QUERYEMPTY", &error), IsOkStatus(&error)); + ASSERT_THAT(quirks()->DropTable(&connection, "queryempty", &error), IsOkStatus(&error)); ASSERT_THAT( - AdbcStatementSetSqlQuery(&statement, "CREATE TABLE QUERYEMPTY (FOO INT)", &error), + AdbcStatementSetSqlQuery(&statement, "CREATE TABLE queryempty (FOO INT)", &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), IsOkStatus(&error)); ASSERT_THAT( - AdbcStatementSetSqlQuery(&statement, "SELECT * FROM QUERYEMPTY WHERE 1=0", &error), + AdbcStatementSetSqlQuery(&statement, "SELECT * FROM queryempty WHERE 1=0", &error), IsOkStatus(&error)); { StreamReader reader; diff --git a/c/validation/adbc_validation_util.cc b/c/validation/adbc_validation_util.cc index 24310aba3d..7d97ad7626 100644 --- a/c/validation/adbc_validation_util.cc +++ b/c/validation/adbc_validation_util.cc @@ -16,7 +16,7 @@ // under the License. #include "adbc_validation_util.h" -#include +#include #include "adbc_validation.h" @@ -36,6 +36,20 @@ std::optional ConnectionGetOption(struct AdbcConnection* connection return std::string(buffer, buffer_size - 1); } +std::optional StatementGetOption(struct AdbcStatement* statement, + std::string_view option, + struct AdbcError* error) { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + AdbcStatusCode status = + AdbcStatementGetOption(statement, option.data(), buffer, &buffer_size, error); + EXPECT_THAT(status, IsOkStatus(error)); + if (status != ADBC_STATUS_OK) return std::nullopt; + EXPECT_GT(buffer_size, 0); + if (buffer_size == 0) return std::nullopt; + return std::string(buffer, buffer_size - 1); +} + std::string StatusCodeToString(AdbcStatusCode code) { #define CASE(CONSTANT) \ case ADBC_STATUS_##CONSTANT: \ @@ -151,16 +165,53 @@ ::testing::Matcher IsStatus(AdbcStatusCode code, } \ } while (false); +static int MakeSchemaColumnImpl(struct ArrowSchema* column, const SchemaField& field) { + switch (field.type) { + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_LIST: + CHECK_ERRNO(ArrowSchemaSetTypeFixedSize(column, field.type, field.fixed_size)); + break; + default: + CHECK_ERRNO(ArrowSchemaSetType(column, field.type)); + break; + } + + CHECK_ERRNO(ArrowSchemaSetName(column, field.name.c_str())); + + if (!field.nullable) { + column->flags &= ~ARROW_FLAG_NULLABLE; + } + + if (static_cast(column->n_children) != field.children.size()) { + return EINVAL; + } + + switch (field.type) { + // SetType for a list will allocate and initialize children + case NANOARROW_TYPE_LIST: + case NANOARROW_TYPE_LARGE_LIST: + case NANOARROW_TYPE_FIXED_SIZE_LIST: + case NANOARROW_TYPE_MAP: { + size_t i = 0; + for (const SchemaField& child : field.children) { + CHECK_ERRNO(MakeSchemaColumnImpl(column->children[i], child)); + ++i; + } + break; + } + default: + break; + } + + return 0; +} + int MakeSchema(struct ArrowSchema* schema, const std::vector& fields) { ArrowSchemaInit(schema); CHECK_ERRNO(ArrowSchemaSetTypeStruct(schema, fields.size())); size_t i = 0; for (const SchemaField& field : fields) { - CHECK_ERRNO(ArrowSchemaSetType(schema->children[i], field.type)); - CHECK_ERRNO(ArrowSchemaSetName(schema->children[i], field.name.c_str())); - if (!field.nullable) { - schema->children[i]->flags &= ~ARROW_FLAG_NULLABLE; - } + CHECK_ERRNO(MakeSchemaColumnImpl(schema->children[i], field)); i++; } return 0; @@ -230,9 +281,7 @@ void MakeStream(struct ArrowArrayStream* stream, struct ArrowSchema* schema, 
stream->private_data = new ConstantArrayStream(schema, std::move(batches)); } -void CompareSchema( - struct ArrowSchema* schema, - const std::vector, ArrowType, bool>>& fields) { +void CompareSchema(struct ArrowSchema* schema, const std::vector& fields) { struct ArrowError na_error; struct ArrowSchemaView view; @@ -247,12 +296,11 @@ void CompareSchema( struct ArrowSchemaView field_view; ASSERT_THAT(ArrowSchemaViewInit(&field_view, schema->children[i], &na_error), IsOkErrno(&na_error)); - ASSERT_EQ(std::get<1>(fields[i]), field_view.type); - ASSERT_EQ(std::get<2>(fields[i]), - (schema->children[i]->flags & ARROW_FLAG_NULLABLE) != 0) + ASSERT_EQ(fields[i].type, field_view.type); + ASSERT_EQ(fields[i].nullable, (schema->children[i]->flags & ARROW_FLAG_NULLABLE) != 0) << "Nullability mismatch"; - if (std::get<0>(fields[i]).has_value()) { - ASSERT_STRCASEEQ(std::get<0>(fields[i])->c_str(), schema->children[i]->name); + if (fields[i].name != "") { + ASSERT_STRCASEEQ(fields[i].name.c_str(), schema->children[i]->name); } } } diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h index 683103e175..b4f5d6f81a 100644 --- a/c/validation/adbc_validation_util.h +++ b/c/validation/adbc_validation_util.h @@ -27,7 +27,7 @@ #include #include -#include +#include #include #include #include @@ -43,6 +43,10 @@ std::optional ConnectionGetOption(struct AdbcConnection* connection std::string_view option, struct AdbcError* error); +std::optional StatementGetOption(struct AdbcStatement* statement, + std::string_view option, + struct AdbcError* error); + // ------------------------------------------------------------ // Helpers to print values @@ -60,6 +64,11 @@ std::string ToString(struct ArrowArrayStream* stream); // ------------------------------------------------------------ // Helper to manage C Data Interface/Nanoarrow resources with RAII +template +struct Initializer { + static void Initialize(T* value) { memset(value, 0, sizeof(T)); } +}; + template struct Releaser { static void Release(T* value) { @@ -69,6 +78,11 @@ struct Releaser { } }; +template <> +struct Initializer { + static void Initialize(struct ArrowBuffer* value) { ArrowBufferInit(value); } +}; + template <> struct Releaser { static void Release(struct ArrowBuffer* buffer) { ArrowBufferReset(buffer); } @@ -126,7 +140,7 @@ template struct Handle { Resource value; - Handle() { std::memset(&value, 0, sizeof(value)); } + Handle() { Initializer::Initialize(&value); } ~Handle() { Releaser::Release(&value); } @@ -242,13 +256,29 @@ struct GetObjectsReader { struct SchemaField { std::string name; ArrowType type = NANOARROW_TYPE_UNINITIALIZED; + int32_t fixed_size = 0; bool nullable = true; + std::vector children; SchemaField(std::string name, ArrowType type, bool nullable) : name(std::move(name)), type(type), nullable(nullable) {} SchemaField(std::string name, ArrowType type) : SchemaField(std::move(name), type, /*nullable=*/true) {} + + static SchemaField Nested(std::string name, ArrowType type, + std::vector children) { + SchemaField out(name, type); + out.children = std::move(children); + return out; + } + + static SchemaField FixedSize(std::string name, ArrowType type, int32_t fixed_size, + std::vector children = {}) { + SchemaField out = Nested(name, type, std::move(children)); + out.fixed_size = fixed_size; + return out; + } }; /// \brief Make a schema from a vector of (name, type, nullable) tuples. 
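For orientation, a minimal sketch of how the SchemaField helpers declared above combine with MakeSchema to describe a nested column; the adbc_validation namespace qualification and the include paths are assumptions made for illustration and are not part of this patch:

    #include <nanoarrow/nanoarrow.h>   // assumed include path for the vendored nanoarrow
    #include "adbc_validation_util.h"  // assumed include path for the helpers above

    void ExampleListOfInt32Schema() {
      // Describe one column "col" of type list<int32>, mirroring the pattern used
      // by TestSqlIngestListOfInt32 in this change.
      adbc_validation::SchemaField field = adbc_validation::SchemaField::Nested(
          "col", NANOARROW_TYPE_LIST, {{"item", NANOARROW_TYPE_INT32}});
      struct ArrowSchema schema;
      // MakeSchema initializes a struct schema and applies each SchemaField
      // (name, type, nullability, fixed size, children) to one child, recursing
      // into nested children.
      if (adbc_validation::MakeSchema(&schema, {field}) != 0) {
        return;  // errno-style failure
      }
      schema.release(&schema);
    }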
@@ -263,54 +293,61 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array, if constexpr (std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value) { - if (int errno_res = ArrowArrayAppendInt(array, *v); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayAppendInt(array, *v)); // XXX: cpplint gets weird here and thinks this is an unbraced if } else if constexpr (std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value) { - if (int errno_res = ArrowArrayAppendUInt(array, *v); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayAppendUInt(array, *v)); } else if constexpr (std::is_same::value || // NOLINT(readability/braces) std::is_same::value) { - if (int errno_res = ArrowArrayAppendDouble(array, *v); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayAppendDouble(array, *v)); } else if constexpr (std::is_same::value) { struct ArrowBufferView view; view.data.as_char = v->c_str(); view.size_bytes = v->size(); - if (int errno_res = ArrowArrayAppendBytes(array, view); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayAppendBytes(array, view)); } else if constexpr (std::is_same>::value) { static_assert(std::is_same_v); struct ArrowBufferView view; view.data.as_uint8 = reinterpret_cast(v->data()); view.size_bytes = v->size(); - if (int errno_res = ArrowArrayAppendBytes(array, view); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayAppendBytes(array, view)); } else if constexpr (std::is_same::value) { - if (int errno_res = ArrowArrayAppendInterval(array, *v); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayAppendInterval(array, *v)); } else if constexpr (std::is_same::value) { - if (int errno_res = ArrowArrayAppendDecimal(array, *v); errno_res != 0) { - return errno_res; + CHECK_OK(ArrowArrayAppendDecimal(array, *v)); + } else if constexpr ( + // Possibly a more effective way to do this using template magic + // Not included but possible are the std::optional<> variants of this + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>>::value) { + using child_t = typename T::value_type; + std::vector> value_nullable; + for (const auto& child_value : *v) { + value_nullable.push_back(child_value); } + CHECK_OK(MakeArray(array, array->children[0], value_nullable)); + CHECK_OK(ArrowArrayFinishElement(array)); } else { static_assert(!sizeof(T), "Not yet implemented"); return ENOTSUP; } } else { - if (int errno_res = ArrowArrayAppendNull(array, 1); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayAppendNull(array, 1)); } } return 0; @@ -326,10 +363,7 @@ template int MakeBatchImpl(struct ArrowArray* batch, size_t i, struct ArrowError* error, const std::vector>& first, const std::vector>&... rest) { - if (int errno_res = MakeArray(batch, batch->children[i], first); - errno_res != 0) { - return errno_res; - } + CHECK_OK(MakeArray(batch, batch->children[i], first)); return MakeBatchImpl(batch, i + 1, error, rest...); } @@ -337,12 +371,8 @@ int MakeBatchImpl(struct ArrowArray* batch, size_t i, struct ArrowError* error, template int MakeBatch(struct ArrowArray* batch, struct ArrowError* error, const std::vector>&... 
columns) { - if (int errno_res = ArrowArrayStartAppending(batch); errno_res != 0) { - return errno_res; - } - if (int errno_res = MakeBatchImpl(batch, 0, error, columns...); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayStartAppending(batch)); + CHECK_OK(MakeBatchImpl(batch, 0, error, columns...)); for (size_t i = 0; i < static_cast(batch->n_children); i++) { if (batch->length > 0 && batch->children[i]->length != batch->length) { ADD_FAILURE() << "Column lengths are inconsistent: column " << i << " has length " @@ -357,9 +387,7 @@ int MakeBatch(struct ArrowArray* batch, struct ArrowError* error, template int MakeBatch(struct ArrowSchema* schema, struct ArrowArray* batch, struct ArrowError* error, const std::vector>&... columns) { - if (int errno_res = ArrowArrayInitFromSchema(batch, schema, error); errno_res != 0) { - return errno_res; - } + CHECK_OK(ArrowArrayInitFromSchema(batch, schema, error)); return MakeBatch(batch, error, columns...); } @@ -370,49 +398,33 @@ void MakeStream(struct ArrowArrayStream* stream, struct ArrowSchema* schema, /// \brief Compare an array for equality against a vector of values. template void CompareArray(struct ArrowArrayView* array, - const std::vector>& values) { - ASSERT_EQ(static_cast(values.size()), array->array->length); - int64_t i = 0; + const std::vector>& values, int64_t offset = 0, + int64_t length = -1) { + if (length == -1) { + length = array->length; + } + ASSERT_EQ(static_cast(values.size()), length); + int64_t i = offset; for (const auto& v : values) { SCOPED_TRACE("Array index " + std::to_string(i)); if (v.has_value()) { ASSERT_FALSE(ArrowArrayViewIsNull(array, i)); - if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_double[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, ArrowBitGet(array->buffer_views[1].data.as_uint8, i)); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int8[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int16[i]); - } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int32[i]); - } else if constexpr (std::is_same::value) { + ASSERT_EQ(ArrowArrayViewGetDoubleUnsafe(array, i), *v); + } else if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int64[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint8[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint16[i]); - } else if constexpr 
(std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint32[i]); - } else if constexpr (std::is_same::value) { + ASSERT_EQ(ArrowArrayViewGetIntUnsafe(array, i), *v); + } else if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint64[i]); + ASSERT_EQ(ArrowArrayViewGetUIntUnsafe(array, i), *v); } else if constexpr (std::is_same::value) { struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i); std::string str(view.data, view.size_bytes); @@ -432,6 +444,34 @@ void CompareArray(struct ArrowArrayView* array, ASSERT_EQ(interval.months, (*v)->months); ASSERT_EQ(interval.days, (*v)->days); ASSERT_EQ(interval.ns, (*v)->ns); + + } else if constexpr ( + // Possibly a more effective way to do this using template magic + // Not included but possible are the std::optional<> variants of this + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>>::value) { + using child_t = typename T::value_type; + std::vector> value_nullable; + for (const auto& child_value : *v) { + value_nullable.push_back(child_value); + } + + SCOPED_TRACE("List item"); + int64_t child_offset = ArrowArrayViewListChildOffset(array, i); + int64_t child_length = ArrowArrayViewListChildOffset(array, i + 1) - child_offset; + CompareArray(array->children[0], value_nullable, child_offset, + child_length); } else { static_assert(!sizeof(T), "Not yet implemented"); } @@ -444,9 +484,7 @@ void CompareArray(struct ArrowArrayView* array, /// \brief Compare a schema for equality against a vector of (name, /// type, nullable) tuples. -void CompareSchema( - struct ArrowSchema* schema, - const std::vector, ArrowType, bool>>& fields); +void CompareSchema(struct ArrowSchema* schema, const std::vector& fields); /// \brief Helper method to get the vendor version of a driver std::string GetDriverVendorVersion(struct AdbcConnection* connection); diff --git a/c/validation/meson.build b/c/validation/meson.build new file mode 100644 index 0000000000..984f4a34fb --- /dev/null +++ b/c/validation/meson.build @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +adbc_validation_util_lib = static_library( + 'adbc_validation_util', + 'adbc_validation_util.cc', + include_directories: [include_dir, driver_dir], + link_with: [adbc_common_lib, adbc_framework_lib, adbc_driver_manager_lib], + dependencies: [nanoarrow_dep, gtest_main_dep, gmock_dep], +) + +adbc_validation_dep = declare_dependency( + sources: [ + 'adbc_validation.cc', + 'adbc_validation_connection.cc', + 'adbc_validation_database.cc', + 'adbc_validation_statement.cc', + ], + include_directories: [include_dir, driver_dir], + link_with: [ + adbc_validation_util_lib, + adbc_common_lib, + adbc_framework_lib, + adbc_driver_manager_lib, + ], + dependencies: [nanoarrow_dep, gtest_main_dep, gmock_dep], +) diff --git a/c/vendor/nanoarrow/nanoarrow.c b/c/vendor/nanoarrow/nanoarrow.c index 0af57027a5..8f2659881b 100644 --- a/c/vendor/nanoarrow/nanoarrow.c +++ b/c/vendor/nanoarrow/nanoarrow.c @@ -66,6 +66,7 @@ void ArrowLayoutInit(struct ArrowLayout* layout, enum ArrowType storage_type) { switch (storage_type) { case NANOARROW_TYPE_UNINITIALIZED: case NANOARROW_TYPE_NA: + case NANOARROW_TYPE_RUN_END_ENCODED: layout->buffer_type[0] = NANOARROW_BUFFER_TYPE_NONE; layout->buffer_data_type[0] = NANOARROW_TYPE_UNINITIALIZED; layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_NONE; @@ -178,6 +179,16 @@ void ArrowLayoutInit(struct ArrowLayout* layout, enum ArrowType storage_type) { layout->buffer_data_type[2] = NANOARROW_TYPE_BINARY; break; + case NANOARROW_TYPE_BINARY_VIEW: + layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[1] = NANOARROW_TYPE_BINARY_VIEW; + layout->element_size_bits[1] = 128; + break; + case NANOARROW_TYPE_STRING_VIEW: + layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[1] = NANOARROW_TYPE_STRING_VIEW; + layout->element_size_bits[1] = 128; + default: break; } @@ -201,7 +212,9 @@ static void ArrowBufferAllocatorMallocFree(struct ArrowBufferAllocator* allocato uint8_t* ptr, int64_t size) { NANOARROW_UNUSED(allocator); NANOARROW_UNUSED(size); - ArrowFree(ptr); + if (ptr != NULL) { + ArrowFree(ptr); + } } static struct ArrowBufferAllocator ArrowBufferAllocatorMalloc = { @@ -211,13 +224,24 @@ struct ArrowBufferAllocator ArrowBufferAllocatorDefault(void) { return ArrowBufferAllocatorMalloc; } -static uint8_t* ArrowBufferAllocatorNeverReallocate( - struct ArrowBufferAllocator* allocator, uint8_t* ptr, int64_t old_size, - int64_t new_size) { - NANOARROW_UNUSED(allocator); - NANOARROW_UNUSED(ptr); - NANOARROW_UNUSED(old_size); +static uint8_t* ArrowBufferDeallocatorReallocate(struct ArrowBufferAllocator* allocator, + uint8_t* ptr, int64_t old_size, + int64_t new_size) { NANOARROW_UNUSED(new_size); + + // Attempting to reallocate a buffer with a custom deallocator is + // a programming error. In debug mode, crash here. +#if defined(NANOARROW_DEBUG) + NANOARROW_PRINT_AND_DIE(ENOMEM, + "It is an error to reallocate a buffer whose allocator is " + "ArrowBufferDeallocator()"); +#endif + + // In release mode, ensure the deallocator is called exactly + // once using the pointer it was given and return NULL, which + // will trigger the caller to return ENOMEM. 
+ allocator->free(allocator, ptr, old_size); + *allocator = ArrowBufferAllocatorDefault(); return NULL; } @@ -226,7 +250,7 @@ struct ArrowBufferAllocator ArrowBufferDeallocator( int64_t size), void* private_data) { struct ArrowBufferAllocator allocator; - allocator.reallocate = &ArrowBufferAllocatorNeverReallocate; + allocator.reallocate = &ArrowBufferDeallocatorReallocate; allocator.free = custom_free; allocator.private_data = private_data; return allocator; @@ -332,6 +356,7 @@ ArrowErrorCode ArrowDecimalSetDigits(struct ArrowDecimal* decimal, // https://github.com/apache/arrow/blob/cd3321b28b0c9703e5d7105d6146c1270bbadd7f/cpp/src/arrow/util/decimal.cc#L365 ArrowErrorCode ArrowDecimalAppendDigitsToBuffer(const struct ArrowDecimal* decimal, struct ArrowBuffer* buffer) { + NANOARROW_DCHECK(decimal->n_words == 2 || decimal->n_words == 4); int is_negative = ArrowDecimalSign(decimal) < 0; uint64_t words_little_endian[4]; @@ -417,6 +442,13 @@ ArrowErrorCode ArrowDecimalAppendDigitsToBuffer(const struct ArrowDecimal* decim // The most significant segment should have no leading zeroes int n_chars = snprintf((char*)buffer->data + buffer->size_bytes, 21, "%lu", (unsigned long)segments[num_segments - 1]); + + // Ensure that an encoding error from snprintf() does not result + // in an out-of-bounds access. + if (n_chars < 0) { + return ERANGE; + } + buffer->size_bytes += n_chars; // Subsequent output needs to be left-padded with zeroes such that each segment @@ -448,6 +480,7 @@ ArrowErrorCode ArrowDecimalAppendDigitsToBuffer(const struct ArrowDecimal* decim // under the License. #include +#include #include #include #include @@ -532,8 +565,12 @@ static const char* ArrowSchemaFormatTemplate(enum ArrowType type) { return "u"; case NANOARROW_TYPE_LARGE_STRING: return "U"; + case NANOARROW_TYPE_STRING_VIEW: + return "vu"; case NANOARROW_TYPE_BINARY: return "z"; + case NANOARROW_TYPE_BINARY_VIEW: + return "vz"; case NANOARROW_TYPE_LARGE_BINARY: return "Z"; @@ -556,6 +593,8 @@ static const char* ArrowSchemaFormatTemplate(enum ArrowType type) { return "+s"; case NANOARROW_TYPE_MAP: return "+m"; + case NANOARROW_TYPE_RUN_END_ENCODED: + return "+r"; default: return NULL; @@ -587,6 +626,13 @@ static int ArrowSchemaInitChildrenIfNeeded(struct ArrowSchema* schema, NANOARROW_RETURN_NOT_OK( ArrowSchemaSetName(schema->children[0]->children[1], "value")); break; + case NANOARROW_TYPE_RUN_END_ENCODED: + NANOARROW_RETURN_NOT_OK(ArrowSchemaAllocateChildren(schema, 2)); + ArrowSchemaInit(schema->children[0]); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetName(schema->children[0], "run_ends")); + schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + ArrowSchemaInit(schema->children[1]); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetName(schema->children[1], "values")); default: break; } @@ -656,15 +702,19 @@ ArrowErrorCode ArrowSchemaSetTypeFixedSize(struct ArrowSchema* schema, int n_chars; switch (type) { case NANOARROW_TYPE_FIXED_SIZE_BINARY: - n_chars = snprintf(buffer, sizeof(buffer), "w:%d", (int)fixed_size); + n_chars = snprintf(buffer, sizeof(buffer), "w:%" PRId32, fixed_size); break; case NANOARROW_TYPE_FIXED_SIZE_LIST: - n_chars = snprintf(buffer, sizeof(buffer), "+w:%d", (int)fixed_size); + n_chars = snprintf(buffer, sizeof(buffer), "+w:%" PRId32, fixed_size); break; default: return EINVAL; } + if (((size_t)n_chars) >= sizeof(buffer) || n_chars < 0) { + return ERANGE; + } + buffer[n_chars] = '\0'; NANOARROW_RETURN_NOT_OK(ArrowSchemaSetFormat(schema, buffer)); @@ -697,10 +747,36 @@ ArrowErrorCode 
ArrowSchemaSetTypeDecimal(struct ArrowSchema* schema, enum ArrowT return EINVAL; } + if (((size_t)n_chars) >= sizeof(buffer) || n_chars < 0) { + return ERANGE; + } + buffer[n_chars] = '\0'; return ArrowSchemaSetFormat(schema, buffer); } +ArrowErrorCode ArrowSchemaSetTypeRunEndEncoded(struct ArrowSchema* schema, + enum ArrowType run_end_type) { + switch (run_end_type) { + case NANOARROW_TYPE_INT16: + case NANOARROW_TYPE_INT32: + case NANOARROW_TYPE_INT64: + break; + default: + return EINVAL; + } + + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetFormat( + schema, ArrowSchemaFormatTemplate(NANOARROW_TYPE_RUN_END_ENCODED))); + NANOARROW_RETURN_NOT_OK( + ArrowSchemaInitChildrenIfNeeded(schema, NANOARROW_TYPE_RUN_END_ENCODED)); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema->children[0], run_end_type)); + NANOARROW_RETURN_NOT_OK( + ArrowSchemaSetType(schema->children[1], NANOARROW_TYPE_UNINITIALIZED)); + + return NANOARROW_OK; +} + static const char* ArrowTimeUnitFormatString(enum ArrowTimeUnit time_unit) { switch (time_unit) { case NANOARROW_TIME_UNIT_SECOND: @@ -773,7 +849,7 @@ ArrowErrorCode ArrowSchemaSetTypeDateTime(struct ArrowSchema* schema, enum Arrow return EINVAL; } - if (((size_t)n_chars) >= sizeof(buffer)) { + if (((size_t)n_chars) >= sizeof(buffer) || n_chars < 0) { return ERANGE; } @@ -810,18 +886,30 @@ ArrowErrorCode ArrowSchemaSetTypeUnion(struct ArrowSchema* schema, enum ArrowTyp return EINVAL; } + // Ensure that an encoding error from snprintf() does not result + // in an out-of-bounds access. + if (n_chars < 0) { + return ERANGE; + } + if (n_children > 0) { n_chars = snprintf(format_cursor, format_out_size, "0"); format_cursor += n_chars; format_out_size -= n_chars; for (int64_t i = 1; i < n_children; i++) { - n_chars = snprintf(format_cursor, format_out_size, ",%d", (int)i); + n_chars = snprintf(format_cursor, format_out_size, ",%" PRId64, i); format_cursor += n_chars; format_out_size -= n_chars; } } + // Ensure that an encoding error from snprintf() does not result + // in an out-of-bounds access. 
+ if (n_chars < 0) { + return ERANGE; + } + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetFormat(schema, format_out)); NANOARROW_RETURN_NOT_OK(ArrowSchemaAllocateChildren(schema, n_children)); @@ -1104,8 +1192,9 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, ArrowSchemaViewSetPrimitive(schema_view, NANOARROW_TYPE_DECIMAL256); return NANOARROW_OK; default: - ArrowErrorSet(error, "Expected decimal bitwidth of 128 or 256 but found %d", - (int)schema_view->decimal_bitwidth); + ArrowErrorSet(error, + "Expected decimal bitwidth of 128 or 256 but found %" PRId32, + schema_view->decimal_bitwidth); return EINVAL; } @@ -1162,6 +1251,13 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, *format_end_out = format + 2; return NANOARROW_OK; + // run end encoded has no buffer at all + case 'r': + schema_view->storage_type = NANOARROW_TYPE_RUN_END_ENCODED; + schema_view->type = NANOARROW_TYPE_RUN_END_ENCODED; + *format_end_out = format + 2; + return NANOARROW_OK; + // just validity buffer case 'w': if (format[2] != ':' || format[3] == '\0') { @@ -1209,11 +1305,10 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, int64_t n_type_ids = _ArrowParseUnionTypeIds(schema_view->union_type_ids, NULL); if (n_type_ids != schema_view->schema->n_children) { - ArrowErrorSet( - error, - "Expected union type_ids parameter to be a comma-separated list of %ld " - "values between 0 and 127 but found '%s'", - (long)schema_view->schema->n_children, schema_view->union_type_ids); + ArrowErrorSet(error, + "Expected union type_ids parameter to be a comma-separated " + "list of %" PRId64 " values between 0 and 127 but found '%s'", + schema_view->schema->n_children, schema_view->union_type_ids); return EINVAL; } *format_end_out = format + strlen(format); @@ -1392,6 +1487,24 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, return EINVAL; } + // view types + case 'v': { + switch (format[1]) { + case 'u': + ArrowSchemaViewSetPrimitive(schema_view, NANOARROW_TYPE_STRING_VIEW); + *format_end_out = format + 2; + return NANOARROW_OK; + case 'z': + ArrowSchemaViewSetPrimitive(schema_view, NANOARROW_TYPE_BINARY_VIEW); + *format_end_out = format + 2; + return NANOARROW_OK; + default: + ArrowErrorSet(error, "Expected 'u', or 'z' following 'v' but found '%s'", + format + 1); + return EINVAL; + } + } + default: ArrowErrorSet(error, "Unknown format: '%s'", format); return EINVAL; @@ -1401,8 +1514,9 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, static ArrowErrorCode ArrowSchemaViewValidateNChildren( struct ArrowSchemaView* schema_view, int64_t n_children, struct ArrowError* error) { if (n_children != -1 && schema_view->schema->n_children != n_children) { - ArrowErrorSet(error, "Expected schema with %d children but found %d children", - (int)n_children, (int)schema_view->schema->n_children); + ArrowErrorSet( + error, "Expected schema with %" PRId64 " children but found %" PRId64 " children", + n_children, schema_view->schema->n_children); return EINVAL; } @@ -1412,15 +1526,15 @@ static ArrowErrorCode ArrowSchemaViewValidateNChildren( for (int64_t i = 0; i < schema_view->schema->n_children; i++) { child = schema_view->schema->children[i]; if (child == NULL) { - ArrowErrorSet(error, - "Expected valid schema at schema->children[%ld] but found NULL", - (long)i); + ArrowErrorSet( + error, "Expected valid schema at schema->children[%" PRId64 "] but found NULL", + i); return EINVAL; } else if 
(child->release == NULL) { - ArrowErrorSet( - error, - "Expected valid schema at schema->children[%ld] but found a released schema", - (long)i); + ArrowErrorSet(error, + "Expected valid schema at schema->children[%" PRId64 + "] but found a released schema", + i); return EINVAL; } } @@ -1438,8 +1552,9 @@ static ArrowErrorCode ArrowSchemaViewValidateMap(struct ArrowSchemaView* schema_ NANOARROW_RETURN_NOT_OK(ArrowSchemaViewValidateNChildren(schema_view, 1, error)); if (schema_view->schema->children[0]->n_children != 2) { - ArrowErrorSet(error, "Expected child of map type to have 2 children but found %d", - (int)schema_view->schema->children[0]->n_children); + ArrowErrorSet(error, + "Expected child of map type to have 2 children but found %" PRId64, + schema_view->schema->children[0]->n_children); return EINVAL; } @@ -1521,6 +1636,8 @@ static ArrowErrorCode ArrowSchemaViewValidate(struct ArrowSchemaView* schema_vie case NANOARROW_TYPE_TIME32: case NANOARROW_TYPE_TIME64: case NANOARROW_TYPE_DURATION: + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: return ArrowSchemaViewValidateNChildren(schema_view, 0, error); case NANOARROW_TYPE_FIXED_SIZE_BINARY: @@ -1536,6 +1653,9 @@ static ArrowErrorCode ArrowSchemaViewValidate(struct ArrowSchemaView* schema_vie case NANOARROW_TYPE_FIXED_SIZE_LIST: return ArrowSchemaViewValidateNChildren(schema_view, 1, error); + case NANOARROW_TYPE_RUN_END_ENCODED: + return ArrowSchemaViewValidateNChildren(schema_view, 2, error); + case NANOARROW_TYPE_STRUCT: return ArrowSchemaViewValidateNChildren(schema_view, -1, error); @@ -1551,7 +1671,7 @@ static ArrowErrorCode ArrowSchemaViewValidate(struct ArrowSchemaView* schema_vie default: ArrowErrorSet(error, "Expected a valid enum ArrowType value but found %d", - (int)schema_view->type); + schema_view->type); return EINVAL; } @@ -1601,8 +1721,8 @@ ArrowErrorCode ArrowSchemaViewInit(struct ArrowSchemaView* schema_view, } if ((format + format_len) != format_end_out) { - ArrowErrorSet(error, "Error parsing schema->format '%s': parsed %d/%d characters", - format, (int)(format_end_out - format), (int)(format_len)); + ArrowErrorSet(error, "Error parsing schema->format '%s': parsed %d/%zu characters", + format, (int)(format_end_out - format), format_len); return EINVAL; } @@ -1662,9 +1782,8 @@ static int64_t ArrowSchemaTypeToStringInternal(struct ArrowSchemaView* schema_vi switch (schema_view->type) { case NANOARROW_TYPE_DECIMAL128: case NANOARROW_TYPE_DECIMAL256: - return snprintf(out, n, "%s(%d, %d)", type_string, - (int)schema_view->decimal_precision, - (int)schema_view->decimal_scale); + return snprintf(out, n, "%s(%" PRId32 ", %" PRId32 ")", type_string, + schema_view->decimal_precision, schema_view->decimal_scale); case NANOARROW_TYPE_TIMESTAMP: return snprintf(out, n, "%s('%s', '%s')", type_string, ArrowTimeUnitString(schema_view->time_unit), schema_view->timezone); @@ -1675,7 +1794,7 @@ static int64_t ArrowSchemaTypeToStringInternal(struct ArrowSchemaView* schema_vi ArrowTimeUnitString(schema_view->time_unit)); case NANOARROW_TYPE_FIXED_SIZE_BINARY: case NANOARROW_TYPE_FIXED_SIZE_LIST: - return snprintf(out, n, "%s(%ld)", type_string, (long)schema_view->fixed_size); + return snprintf(out, n, "%s(%" PRId32 ")", type_string, schema_view->fixed_size); case NANOARROW_TYPE_SPARSE_UNION: case NANOARROW_TYPE_DENSE_UNION: return snprintf(out, n, "%s([%s])", type_string, schema_view->union_type_ids); @@ -1688,6 +1807,12 @@ static int64_t ArrowSchemaTypeToStringInternal(struct ArrowSchemaView* schema_vi // among 
multiple sprintf calls. static inline void ArrowToStringLogChars(char** out, int64_t n_chars_last, int64_t* n_remaining, int64_t* n_chars) { + // In the unlikely snprintf() returning a negative value (encoding error), + // ensure the result won't cause an out-of-bounds access. + if (n_chars_last < 0) { + n_chars_last = 0; + } + *n_chars += n_chars_last; *n_remaining -= n_chars_last; @@ -1783,7 +1908,12 @@ int64_t ArrowSchemaToString(const struct ArrowSchema* schema, char* out, int64_t n_chars += snprintf(out, n, ">"); } - return n_chars; + // Ensure that we always return a positive result + if (n_chars > 0) { + return n_chars; + } else { + return 0; + } } ArrowErrorCode ArrowMetadataReaderInit(struct ArrowMetadataReader* reader, @@ -2019,6 +2149,10 @@ ArrowErrorCode ArrowMetadataBuilderRemove(struct ArrowBuffer* buffer, // under the License. #include +#include +#include +#include +#include #include #include @@ -2032,6 +2166,12 @@ static void ArrowArrayReleaseInternal(struct ArrowArray* array) { ArrowBitmapReset(&private_data->bitmap); ArrowBufferReset(&private_data->buffers[0]); ArrowBufferReset(&private_data->buffers[1]); + ArrowFree(private_data->buffer_data); + for (int32_t i = 0; i < private_data->n_variadic_buffers; ++i) { + ArrowBufferReset(&private_data->variadic_buffers[i]); + } + ArrowFree(private_data->variadic_buffers); + ArrowFree(private_data->variadic_buffer_sizes); ArrowFree(private_data); } @@ -2072,6 +2212,7 @@ static ArrowErrorCode ArrowArraySetStorageType(struct ArrowArray* array, switch (storage_type) { case NANOARROW_TYPE_UNINITIALIZED: case NANOARROW_TYPE_NA: + case NANOARROW_TYPE_RUN_END_ENCODED: array->n_buffers = 0; break; @@ -2105,7 +2246,10 @@ static ArrowErrorCode ArrowArraySetStorageType(struct ArrowArray* array, case NANOARROW_TYPE_DENSE_UNION: array->n_buffers = 2; break; - + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + array->n_buffers = NANOARROW_BINARY_VIEW_FIXED_BUFFERS + 1; + break; case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: case NANOARROW_TYPE_BINARY: @@ -2148,12 +2292,36 @@ ArrowErrorCode ArrowArrayInitFromType(struct ArrowArray* array, ArrowBitmapInit(&private_data->bitmap); ArrowBufferInit(&private_data->buffers[0]); ArrowBufferInit(&private_data->buffers[1]); - private_data->buffer_data[0] = NULL; - private_data->buffer_data[1] = NULL; - private_data->buffer_data[2] = NULL; + private_data->buffer_data = + (const void**)ArrowMalloc(sizeof(void*) * NANOARROW_MAX_FIXED_BUFFERS); + for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; ++i) { + private_data->buffer_data[i] = NULL; + } + private_data->n_variadic_buffers = 0; + private_data->variadic_buffers = NULL; + private_data->variadic_buffer_sizes = NULL; array->private_data = private_data; - array->buffers = (const void**)(&private_data->buffer_data); + array->buffers = (const void**)(private_data->buffer_data); + + // These are not technically "storage" in the sense that they do not appear + // in the ArrowSchemaView's storage_type member; however, allowing them here + // is helpful to maximize the number of types that can avoid going through + // ArrowArrayInitFromSchema(). 
+ switch (storage_type) { + case NANOARROW_TYPE_DURATION: + case NANOARROW_TYPE_TIMESTAMP: + case NANOARROW_TYPE_TIME64: + case NANOARROW_TYPE_DATE64: + storage_type = NANOARROW_TYPE_INT64; + break; + case NANOARROW_TYPE_TIME32: + case NANOARROW_TYPE_DATE32: + storage_type = NANOARROW_TYPE_INT32; + break; + default: + break; + } int result = ArrowArraySetStorageType(array, storage_type); if (result != NANOARROW_OK) { @@ -2410,19 +2578,16 @@ static ArrowErrorCode ArrowArrayFinalizeBuffers(struct ArrowArray* array) { struct ArrowArrayPrivateData* private_data = (struct ArrowArrayPrivateData*)array->private_data; - // The only buffer finalizing this currently does is make sure the data - // buffer for (Large)String|Binary is never NULL - switch (private_data->storage_type) { - case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_BINARY: - case NANOARROW_TYPE_LARGE_STRING: - if (ArrowArrayBuffer(array, 2)->data == NULL) { - NANOARROW_RETURN_NOT_OK(ArrowBufferAppendUInt8(ArrowArrayBuffer(array, 2), 0)); - } - break; - default: - break; + for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + if (private_data->layout.buffer_type[i] == NANOARROW_BUFFER_TYPE_VALIDITY || + private_data->layout.buffer_type[i] == NANOARROW_BUFFER_TYPE_NONE) { + continue; + } + + struct ArrowBuffer* buffer = ArrowArrayBuffer(array, i); + if (buffer->data == NULL) { + NANOARROW_RETURN_NOT_OK((ArrowBufferReserve(buffer, 1))); + } } for (int64_t i = 0; i < array->n_children; i++) { @@ -2440,10 +2605,26 @@ static void ArrowArrayFlushInternalPointers(struct ArrowArray* array) { struct ArrowArrayPrivateData* private_data = (struct ArrowArrayPrivateData*)array->private_data; - for (int64_t i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + const bool is_binary_view = private_data->storage_type == NANOARROW_TYPE_STRING_VIEW || + private_data->storage_type == NANOARROW_TYPE_BINARY_VIEW; + const int32_t nfixed_buf = is_binary_view ? 2 : NANOARROW_MAX_FIXED_BUFFERS; + + for (int32_t i = 0; i < nfixed_buf; i++) { private_data->buffer_data[i] = ArrowArrayBuffer(array, i)->data; } + if (is_binary_view) { + const int32_t nvirt_buf = private_data->n_variadic_buffers; + private_data->buffer_data = (const void**)ArrowRealloc( + private_data->buffer_data, sizeof(void*) * (nfixed_buf + nvirt_buf + 1)); + for (int32_t i = 0; i < nvirt_buf; i++) { + private_data->buffer_data[nfixed_buf + i] = private_data->variadic_buffers[i].data; + } + private_data->buffer_data[nfixed_buf + nvirt_buf] = + private_data->variadic_buffer_sizes; + array->buffers = (const void**)(private_data->buffer_data); + } + for (int64_t i = 0; i < array->n_children; i++) { ArrowArrayFlushInternalPointers(array->children[i]); } @@ -2458,7 +2639,8 @@ ArrowErrorCode ArrowArrayFinishBuilding(struct ArrowArray* array, struct ArrowError* error) { // Even if the data buffer is size zero, the pointer value needed to be non-null // in some implementations (at least one version of Arrow C++ at the time this - // was added). Only do this fix if we can assume CPU data access. + // was added and C# as later discovered). Only do this fix if we can assume + // CPU data access. 
if (validation_level >= NANOARROW_VALIDATION_LEVEL_DEFAULT) { NANOARROW_RETURN_NOT_OK_WITH_ERROR(ArrowArrayFinalizeBuffers(array), error); } @@ -2498,6 +2680,11 @@ ArrowErrorCode ArrowArrayViewAllocateChildren(struct ArrowArrayView* array_view, return EINVAL; } + if (n_children == 0) { + array_view->n_children = 0; + return NANOARROW_OK; + } + array_view->children = (struct ArrowArrayView**)ArrowMalloc(n_children * sizeof(struct ArrowArrayView*)); if (array_view->children == NULL) { @@ -2646,6 +2833,8 @@ void ArrowArrayViewSetLength(struct ArrowArrayView* array_view, int64_t length) case NANOARROW_BUFFER_TYPE_UNION_OFFSET: array_view->buffer_views[i].size_bytes = element_size_bytes * length; continue; + case NANOARROW_BUFFER_TYPE_VARIADIC_DATA: + case NANOARROW_BUFFER_TYPE_VARIADIC_SIZE: case NANOARROW_BUFFER_TYPE_NONE: array_view->buffer_views[i].size_bytes = 0; continue; @@ -2678,9 +2867,16 @@ static int ArrowArrayViewSetArrayInternal(struct ArrowArrayView* array_view, array_view->offset = array->offset; array_view->length = array->length; array_view->null_count = array->null_count; + array_view->variadic_buffer_sizes = NULL; + array_view->variadic_buffers = NULL; + array_view->n_variadic_buffers = 0; int64_t buffers_required = 0; - for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + const int nfixed_buf = array_view->storage_type == NANOARROW_TYPE_STRING_VIEW || + array_view->storage_type == NANOARROW_TYPE_BINARY_VIEW + ? NANOARROW_BINARY_VIEW_FIXED_BUFFERS + : NANOARROW_MAX_FIXED_BUFFERS; + for (int i = 0; i < nfixed_buf; i++) { if (array_view->layout.buffer_type[i] == NANOARROW_BUFFER_TYPE_NONE) { break; } @@ -2698,17 +2894,30 @@ static int ArrowArrayViewSetArrayInternal(struct ArrowArrayView* array_view, } } - // Check the number of buffers + if (array_view->storage_type == NANOARROW_TYPE_STRING_VIEW || + array_view->storage_type == NANOARROW_TYPE_BINARY_VIEW) { + const int64_t n_buffers = array->n_buffers; + const int32_t nfixed_buf = NANOARROW_BINARY_VIEW_FIXED_BUFFERS; + + const int32_t nvariadic_buf = (int32_t)(n_buffers - nfixed_buf - 1); + array_view->n_variadic_buffers = nvariadic_buf; + buffers_required += nvariadic_buf + 1; + array_view->variadic_buffers = array->buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS; + array_view->variadic_buffer_sizes = (int64_t*)array->buffers[n_buffers - 1]; + } + if (buffers_required != array->n_buffers) { - ArrowErrorSet(error, "Expected array with %d buffer(s) but found %d buffer(s)", - (int)buffers_required, (int)array->n_buffers); + ArrowErrorSet(error, + "Expected array with %" PRId64 " buffer(s) but found %" PRId64 + " buffer(s)", + buffers_required, array->n_buffers); return EINVAL; } // Check number of children if (array_view->n_children != array->n_children) { - ArrowErrorSet(error, "Expected %ld children but found %ld children", - (long)array_view->n_children, (long)array->n_children); + ArrowErrorSet(error, "Expected %" PRId64 " children but found %" PRId64 " children", + array_view->n_children, array->n_children); return EINVAL; } @@ -2740,14 +2949,20 @@ static int ArrowArrayViewSetArrayInternal(struct ArrowArrayView* array_view, static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, struct ArrowError* error) { if (array_view->length < 0) { - ArrowErrorSet(error, "Expected length >= 0 but found length %ld", - (long)array_view->length); + ArrowErrorSet(error, "Expected length >= 0 but found length %" PRId64, + array_view->length); return EINVAL; } if (array_view->offset < 0) { - ArrowErrorSet(error, 
"Expected offset >= 0 but found offset %ld", - (long)array_view->offset); + ArrowErrorSet(error, "Expected offset >= 0 but found offset %" PRId64, + array_view->offset); + return EINVAL; + } + + // Ensure that offset + length fits within an int64 before a possible overflow + if ((uint64_t)array_view->offset + (uint64_t)array_view->length > (uint64_t)INT64_MAX) { + ArrowErrorSet(error, "Offset + length is > INT64_MAX"); return EINVAL; } @@ -2760,7 +2975,9 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, for (int i = 0; i < 2; i++) { int64_t element_size_bytes = array_view->layout.element_size_bits[i] / 8; // Initialize with a value that will cause an error if accidentally used uninitialized - int64_t min_buffer_size_bytes = array_view->buffer_views[i].size_bytes + 1; + // Need to suppress the clang-tidy warning because gcc warns for possible use + int64_t min_buffer_size_bytes = // NOLINT(clang-analyzer-deadcode.DeadStores) + array_view->buffer_views[i].size_bytes + 1; switch (array_view->layout.buffer_type[i]) { case NANOARROW_BUFFER_TYPE_VALIDITY: @@ -2786,6 +3003,8 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, case NANOARROW_BUFFER_TYPE_UNION_OFFSET: min_buffer_size_bytes = element_size_bytes * offset_plus_length; break; + case NANOARROW_BUFFER_TYPE_VARIADIC_DATA: + case NANOARROW_BUFFER_TYPE_VARIADIC_SIZE: case NANOARROW_BUFFER_TYPE_NONE: continue; } @@ -2795,11 +3014,11 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, array_view->buffer_views[i].size_bytes = min_buffer_size_bytes; } else if (array_view->buffer_views[i].size_bytes < min_buffer_size_bytes) { ArrowErrorSet(error, - "Expected %s array buffer %d to have size >= %ld bytes but found " - "buffer with %ld bytes", - ArrowTypeString(array_view->storage_type), (int)i, - (long)min_buffer_size_bytes, - (long)array_view->buffer_views[i].size_bytes); + "Expected %s array buffer %d to have size >= %" PRId64 + " bytes but found " + "buffer with %" PRId64 " bytes", + ArrowTypeString(array_view->storage_type), i, min_buffer_size_bytes, + array_view->buffer_views[i].size_bytes); return EINVAL; } } @@ -2811,11 +3030,20 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, case NANOARROW_TYPE_FIXED_SIZE_LIST: case NANOARROW_TYPE_MAP: if (array_view->n_children != 1) { - ArrowErrorSet(error, "Expected 1 child of %s array but found %ld child arrays", - ArrowTypeString(array_view->storage_type), - (long)array_view->n_children); + ArrowErrorSet(error, + "Expected 1 child of %s array but found %" PRId64 " child arrays", + ArrowTypeString(array_view->storage_type), array_view->n_children); + return EINVAL; + } + break; + case NANOARROW_TYPE_RUN_END_ENCODED: + if (array_view->n_children != 2) { + ArrowErrorSet( + error, "Expected 2 children for %s array but found %" PRId64 " child arrays", + ArrowTypeString(array_view->storage_type), array_view->n_children); return EINVAL; } + break; default: break; } @@ -2829,12 +3057,11 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, child_min_length = (array_view->offset + array_view->length); for (int64_t i = 0; i < array_view->n_children; i++) { if (array_view->children[i]->length < child_min_length) { - ArrowErrorSet( - error, - "Expected struct child %d to have length >= %ld but found child with " - "length %ld", - (int)(i + 1), (long)(child_min_length), - (long)array_view->children[i]->length); + ArrowErrorSet(error, + "Expected struct child %" PRId64 " to have 
length >= %" PRId64 + " but found child with " + "length %" PRId64, + i + 1, child_min_length, array_view->children[i]->length); return EINVAL; } } @@ -2845,12 +3072,78 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, array_view->layout.child_size_elements; if (array_view->children[0]->length < child_min_length) { ArrowErrorSet(error, - "Expected child of fixed_size_list array to have length >= %ld but " - "found array with length %ld", - (long)child_min_length, (long)array_view->children[0]->length); + "Expected child of fixed_size_list array to have length >= %" PRId64 + " but " + "found array with length %" PRId64, + child_min_length, array_view->children[0]->length); + return EINVAL; + } + break; + + case NANOARROW_TYPE_RUN_END_ENCODED: { + if (array_view->n_children != 2) { + ArrowErrorSet(error, + "Expected 2 children for run-end encoded array but found %" PRId64, + array_view->n_children); + return EINVAL; + } + struct ArrowArrayView* run_ends_view = array_view->children[0]; + struct ArrowArrayView* values_view = array_view->children[1]; + int64_t max_length; + switch (run_ends_view->storage_type) { + case NANOARROW_TYPE_INT16: + max_length = INT16_MAX; + break; + case NANOARROW_TYPE_INT32: + max_length = INT32_MAX; + break; + case NANOARROW_TYPE_INT64: + max_length = INT64_MAX; + break; + default: + ArrowErrorSet( + error, + "Run-end encoded array only supports INT16, INT32 or INT64 run-ends " + "but found run-ends type %s", + ArrowTypeString(run_ends_view->storage_type)); + return EINVAL; + } + + // There is already a check above that offset_plus_length < INT64_MAX + if (offset_plus_length > max_length) { + ArrowErrorSet(error, + "Offset + length of a run-end encoded array must fit in a value" + " of the run end type %s but is %" PRId64 " + %" PRId64, + ArrowTypeString(run_ends_view->storage_type), array_view->offset, + array_view->length); + return EINVAL; + } + + if (run_ends_view->length > values_view->length) { + ArrowErrorSet(error, + "Length of run_ends is greater than the length of values: %" PRId64 + " > %" PRId64, + run_ends_view->length, values_view->length); + return EINVAL; + } + + if (run_ends_view->length == 0 && values_view->length != 0) { + ArrowErrorSet(error, + "Run-end encoded array has zero length %" PRId64 + ", but values array has " + "non-zero length", + values_view->length); + return EINVAL; + } + + if (run_ends_view->null_count != 0) { + ArrowErrorSet(error, "Null count must be 0 for run ends array, but is %" PRId64, + run_ends_view->null_count); return EINVAL; } break; + } + default: break; } @@ -2886,64 +3179,83 @@ static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_BINARY: if (array_view->buffer_views[1].size_bytes != 0) { - first_offset = array_view->buffer_views[1].data.as_int32[0]; + first_offset = array_view->buffer_views[1].data.as_int32[array_view->offset]; if (first_offset < 0) { - ArrowErrorSet(error, "Expected first offset >= 0 but found %ld", - (long)first_offset); + ArrowErrorSet(error, "Expected first offset >= 0 but found %" PRId64, + first_offset); return EINVAL; } last_offset = array_view->buffer_views[1].data.as_int32[offset_plus_length]; + if (last_offset < 0) { + ArrowErrorSet(error, "Expected last offset >= 0 but found %" PRId64, + last_offset); + return EINVAL; + } // If the data buffer size is unknown, assign it; otherwise, check it if (array_view->buffer_views[2].size_bytes == -1) { array_view->buffer_views[2].size_bytes = 
last_offset; } else if (array_view->buffer_views[2].size_bytes < last_offset) { ArrowErrorSet(error, - "Expected %s array buffer 2 to have size >= %ld bytes but found " - "buffer with %ld bytes", - ArrowTypeString(array_view->storage_type), (long)last_offset, - (long)array_view->buffer_views[2].size_bytes); + "Expected %s array buffer 2 to have size >= %" PRId64 + " bytes but found " + "buffer with %" PRId64 " bytes", + ArrowTypeString(array_view->storage_type), last_offset, + array_view->buffer_views[2].size_bytes); return EINVAL; } + } else if (array_view->buffer_views[2].size_bytes == -1) { + // If the data buffer size is unknown and there are no bytes in the offset buffer, + // set the data buffer size to 0. + array_view->buffer_views[2].size_bytes = 0; } break; case NANOARROW_TYPE_LARGE_STRING: case NANOARROW_TYPE_LARGE_BINARY: if (array_view->buffer_views[1].size_bytes != 0) { - first_offset = array_view->buffer_views[1].data.as_int64[0]; + first_offset = array_view->buffer_views[1].data.as_int64[array_view->offset]; if (first_offset < 0) { - ArrowErrorSet(error, "Expected first offset >= 0 but found %ld", - (long)first_offset); + ArrowErrorSet(error, "Expected first offset >= 0 but found %" PRId64, + first_offset); return EINVAL; } last_offset = array_view->buffer_views[1].data.as_int64[offset_plus_length]; + if (last_offset < 0) { + ArrowErrorSet(error, "Expected last offset >= 0 but found %" PRId64, + last_offset); + return EINVAL; + } // If the data buffer size is unknown, assign it; otherwise, check it if (array_view->buffer_views[2].size_bytes == -1) { array_view->buffer_views[2].size_bytes = last_offset; } else if (array_view->buffer_views[2].size_bytes < last_offset) { ArrowErrorSet(error, - "Expected %s array buffer 2 to have size >= %ld bytes but found " - "buffer with %ld bytes", - ArrowTypeString(array_view->storage_type), (long)last_offset, - (long)array_view->buffer_views[2].size_bytes); + "Expected %s array buffer 2 to have size >= %" PRId64 + " bytes but found " + "buffer with %" PRId64 " bytes", + ArrowTypeString(array_view->storage_type), last_offset, + array_view->buffer_views[2].size_bytes); return EINVAL; } + } else if (array_view->buffer_views[2].size_bytes == -1) { + // If the data buffer size is unknown and there are no bytes in the offset + // buffer, set the data buffer size to 0. 
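        // (This mirrors the 32-bit offset branch above, which applies the same
        // default when the offsets buffer is empty.)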
+ array_view->buffer_views[2].size_bytes = 0; } break; case NANOARROW_TYPE_STRUCT: for (int64_t i = 0; i < array_view->n_children; i++) { if (array_view->children[i]->length < offset_plus_length) { - ArrowErrorSet( - error, - "Expected struct child %d to have length >= %ld but found child with " - "length %ld", - (int)(i + 1), (long)offset_plus_length, - (long)array_view->children[i]->length); + ArrowErrorSet(error, + "Expected struct child %" PRId64 " to have length >= %" PRId64 + " but found child with " + "length %" PRId64, + i + 1, offset_plus_length, array_view->children[i]->length); return EINVAL; } } @@ -2952,21 +3264,27 @@ static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, case NANOARROW_TYPE_LIST: case NANOARROW_TYPE_MAP: if (array_view->buffer_views[1].size_bytes != 0) { - first_offset = array_view->buffer_views[1].data.as_int32[0]; + first_offset = array_view->buffer_views[1].data.as_int32[array_view->offset]; if (first_offset < 0) { - ArrowErrorSet(error, "Expected first offset >= 0 but found %ld", - (long)first_offset); + ArrowErrorSet(error, "Expected first offset >= 0 but found %" PRId64, + first_offset); return EINVAL; } last_offset = array_view->buffer_views[1].data.as_int32[offset_plus_length]; + if (last_offset < 0) { + ArrowErrorSet(error, "Expected last offset >= 0 but found %" PRId64, + last_offset); + return EINVAL; + } + if (array_view->children[0]->length < last_offset) { - ArrowErrorSet( - error, - "Expected child of %s array to have length >= %ld but found array with " - "length %ld", - ArrowTypeString(array_view->storage_type), (long)last_offset, - (long)array_view->children[0]->length); + ArrowErrorSet(error, + "Expected child of %s array to have length >= %" PRId64 + " but found array with " + "length %" PRId64, + ArrowTypeString(array_view->storage_type), last_offset, + array_view->children[0]->length); return EINVAL; } } @@ -2974,24 +3292,58 @@ static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, case NANOARROW_TYPE_LARGE_LIST: if (array_view->buffer_views[1].size_bytes != 0) { - first_offset = array_view->buffer_views[1].data.as_int64[0]; + first_offset = array_view->buffer_views[1].data.as_int64[array_view->offset]; if (first_offset < 0) { - ArrowErrorSet(error, "Expected first offset >= 0 but found %ld", - (long)first_offset); + ArrowErrorSet(error, "Expected first offset >= 0 but found %" PRId64, + first_offset); return EINVAL; } last_offset = array_view->buffer_views[1].data.as_int64[offset_plus_length]; + if (last_offset < 0) { + ArrowErrorSet(error, "Expected last offset >= 0 but found %" PRId64, + last_offset); + return EINVAL; + } + if (array_view->children[0]->length < last_offset) { - ArrowErrorSet( - error, - "Expected child of large list array to have length >= %ld but found array " - "with length %ld", - (long)last_offset, (long)array_view->children[0]->length); + ArrowErrorSet(error, + "Expected child of large list array to have length >= %" PRId64 + " but found array " + "with length %" PRId64, + last_offset, array_view->children[0]->length); return EINVAL; } } break; + + case NANOARROW_TYPE_RUN_END_ENCODED: { + struct ArrowArrayView* run_ends_view = array_view->children[0]; + if (run_ends_view->length == 0) { + break; + } + + int64_t first_run_end = ArrowArrayViewGetIntUnsafe(run_ends_view, 0); + if (first_run_end < 1) { + ArrowErrorSet( + error, + "All run ends must be greater than 0 but the first run end is %" PRId64, + first_run_end); + return EINVAL; + } + + // offset + length < INT64_MAX 
is checked in ArrowArrayViewValidateMinimal() + int64_t last_run_end = + ArrowArrayViewGetIntUnsafe(run_ends_view, run_ends_view->length - 1); + if (last_run_end < offset_plus_length) { + ArrowErrorSet(error, + "Last run end is %" PRId64 " but it should be >= (%" PRId64 + " + %" PRId64 ")", + last_run_end, array_view->offset, array_view->length); + return EINVAL; + } + break; + } default: break; } @@ -3044,7 +3396,7 @@ static int ArrowAssertIncreasingInt32(struct ArrowBufferView view, for (int64_t i = 1; i < view.size_bytes / (int64_t)sizeof(int32_t); i++) { if (view.data.as_int32[i] < view.data.as_int32[i - 1]) { - ArrowErrorSet(error, "[%ld] Expected element size >= 0", (long)i); + ArrowErrorSet(error, "[%" PRId64 "] Expected element size >= 0", i); return EINVAL; } } @@ -3060,7 +3412,7 @@ static int ArrowAssertIncreasingInt64(struct ArrowBufferView view, for (int64_t i = 1; i < view.size_bytes / (int64_t)sizeof(int64_t); i++) { if (view.data.as_int64[i] < view.data.as_int64[i - 1]) { - ArrowErrorSet(error, "[%ld] Expected element size >= 0", (long)i); + ArrowErrorSet(error, "[%" PRId64 "] Expected element size >= 0", i); return EINVAL; } } @@ -3073,8 +3425,9 @@ static int ArrowAssertRangeInt8(struct ArrowBufferView view, int8_t min_value, for (int64_t i = 0; i < view.size_bytes; i++) { if (view.data.as_int8[i] < min_value || view.data.as_int8[i] > max_value) { ArrowErrorSet(error, - "[%ld] Expected buffer value between %d and %d but found value %d", - (long)i, (int)min_value, (int)max_value, (int)view.data.as_int8[i]); + "[%" PRId64 "] Expected buffer value between %" PRId8 " and %" PRId8 + " but found value %" PRId8, + i, min_value, max_value, view.data.as_int8[i]); return EINVAL; } } @@ -3094,8 +3447,8 @@ static int ArrowAssertInt8In(struct ArrowBufferView view, const int8_t* values, } if (!item_found) { - ArrowErrorSet(error, "[%ld] Unexpected buffer value %d", (long)i, - (int)view.data.as_int8[i]); + ArrowErrorSet(error, "[%" PRId64 "] Unexpected buffer value %" PRId8, i, + view.data.as_int8[i]); return EINVAL; } } @@ -3107,13 +3460,24 @@ static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, struct ArrowError* error) { for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { switch (array_view->layout.buffer_type[i]) { + // Only validate the portion of the buffer that is strictly required, + // which includes not validating the offset buffer of a zero-length array. 
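      // For a sliced array this means only the array->length + 1 offsets
      // starting at array->offset are checked for monotonicity (see the
      // sliced_offsets views constructed below).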
case NANOARROW_BUFFER_TYPE_DATA_OFFSET: + if (array_view->length == 0) { + continue; + } if (array_view->layout.element_size_bits[i] == 32) { - NANOARROW_RETURN_NOT_OK( - ArrowAssertIncreasingInt32(array_view->buffer_views[i], error)); + struct ArrowBufferView sliced_offsets; + sliced_offsets.data.as_int32 = + array_view->buffer_views[i].data.as_int32 + array_view->offset; + sliced_offsets.size_bytes = (array_view->length + 1) * sizeof(int32_t); + NANOARROW_RETURN_NOT_OK(ArrowAssertIncreasingInt32(sliced_offsets, error)); } else { - NANOARROW_RETURN_NOT_OK( - ArrowAssertIncreasingInt64(array_view->buffer_views[i], error)); + struct ArrowBufferView sliced_offsets; + sliced_offsets.data.as_int64 = + array_view->buffer_views[i].data.as_int64 + array_view->offset; + sliced_offsets.size_bytes = (array_view->length + 1) * sizeof(int64_t); + NANOARROW_RETURN_NOT_OK(ArrowAssertIncreasingInt64(sliced_offsets, error)); } break; default: @@ -3123,6 +3487,15 @@ static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, if (array_view->storage_type == NANOARROW_TYPE_DENSE_UNION || array_view->storage_type == NANOARROW_TYPE_SPARSE_UNION) { + struct ArrowBufferView sliced_type_ids; + sliced_type_ids.size_bytes = array_view->length * sizeof(int8_t); + if (array_view->length > 0) { + sliced_type_ids.data.as_int8 = + array_view->buffer_views[0].data.as_int8 + array_view->offset; + } else { + sliced_type_ids.data.as_int8 = NULL; + } + if (array_view->union_type_id_map == NULL) { // If the union_type_id map is NULL (e.g., when using ArrowArrayInitFromType() + // ArrowArrayAllocateChildren() + ArrowArrayFinishBuilding()), we don't have enough @@ -3134,9 +3507,9 @@ static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, array_view->union_type_id_map, array_view->n_children, array_view->n_children)) { NANOARROW_RETURN_NOT_OK(ArrowAssertRangeInt8( - array_view->buffer_views[0], 0, (int8_t)(array_view->n_children - 1), error)); + sliced_type_ids, 0, (int8_t)(array_view->n_children - 1), error)); } else { - NANOARROW_RETURN_NOT_OK(ArrowAssertInt8In(array_view->buffer_views[0], + NANOARROW_RETURN_NOT_OK(ArrowAssertInt8In(sliced_type_ids, array_view->union_type_id_map + 128, array_view->n_children, error)); } @@ -3150,16 +3523,37 @@ static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, int64_t offset = ArrowArrayViewUnionChildOffset(array_view, i); int64_t child_length = array_view->children[child_id]->length; if (offset < 0 || offset > child_length) { - ArrowErrorSet( - error, - "[%ld] Expected union offset for child id %d to be between 0 and %ld but " - "found offset value %ld", - (long)i, (int)child_id, (long)child_length, (long)offset); + ArrowErrorSet(error, + "[%" PRId64 "] Expected union offset for child id %" PRId8 + " to be between 0 and %" PRId64 + " but " + "found offset value %" PRId64, + i, child_id, child_length, offset); return EINVAL; } } } + if (array_view->storage_type == NANOARROW_TYPE_RUN_END_ENCODED) { + struct ArrowArrayView* run_ends_view = array_view->children[0]; + if (run_ends_view->length > 0) { + int64_t last_run_end = ArrowArrayViewGetIntUnsafe(run_ends_view, 0); + for (int64_t i = 1; i < run_ends_view->length; i++) { + const int64_t run_end = ArrowArrayViewGetIntUnsafe(run_ends_view, i); + if (run_end <= last_run_end) { + ArrowErrorSet( + error, + "Every run end must be strictly greater than the previous run end, " + "but run_ends[%" PRId64 " is %" PRId64 " and run_ends[%" PRId64 + "] is %" PRId64, + i, run_end, i - 1, last_run_end); + 
return EINVAL; + } + last_run_end = run_end; + } + } + } + // Recurse for children for (int64_t i = 0; i < array_view->n_children; i++) { NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateFull(array_view->children[i], error)); @@ -3192,6 +3586,136 @@ ArrowErrorCode ArrowArrayViewValidate(struct ArrowArrayView* array_view, ArrowErrorSet(error, "validation_level not recognized"); return EINVAL; } + +struct ArrowComparisonInternalState { + enum ArrowCompareLevel level; + int is_equal; + struct ArrowError* reason; +}; + +NANOARROW_CHECK_PRINTF_ATTRIBUTE static void ArrowComparePrependPath( + struct ArrowError* out, const char* fmt, ...) { + if (out == NULL) { + return; + } + + char prefix[128]; + prefix[0] = '\0'; + va_list args; + va_start(args, fmt); + int prefix_len = vsnprintf(prefix, sizeof(prefix), fmt, args); + va_end(args); + + if (prefix_len <= 0) { + return; + } + + size_t out_len = strlen(out->message); + size_t out_len_to_move = sizeof(struct ArrowError) - prefix_len - 1; + if (out_len_to_move > out_len) { + out_len_to_move = out_len; + } + + memmove(out->message + prefix_len, out->message, out_len_to_move); + memcpy(out->message, prefix, prefix_len); + out->message[out_len + prefix_len] = '\0'; +} + +#define SET_NOT_EQUAL_AND_RETURN_IF_IMPL(cond_, state_, reason_) \ + do { \ + if (cond_) { \ + ArrowErrorSet(state_->reason, ": %s", reason_); \ + state_->is_equal = 0; \ + return; \ + } \ + } while (0) + +#define SET_NOT_EQUAL_AND_RETURN_IF(condition_, state_) \ + SET_NOT_EQUAL_AND_RETURN_IF_IMPL(condition_, state_, #condition_) + +static void ArrowArrayViewCompareBuffer(const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, int i, + struct ArrowComparisonInternalState* state) { + SET_NOT_EQUAL_AND_RETURN_IF( + actual->buffer_views[i].size_bytes != expected->buffer_views[i].size_bytes, state); + + int64_t buffer_size = actual->buffer_views[i].size_bytes; + if (buffer_size > 0) { + SET_NOT_EQUAL_AND_RETURN_IF( + memcmp(actual->buffer_views[i].data.data, expected->buffer_views[i].data.data, + buffer_size) != 0, + state); + } +} + +static void ArrowArrayViewCompareIdentical(const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, + struct ArrowComparisonInternalState* state) { + SET_NOT_EQUAL_AND_RETURN_IF(actual->storage_type != expected->storage_type, state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->n_children != expected->n_children, state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->dictionary == NULL && expected->dictionary != NULL, + state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->dictionary != NULL && expected->dictionary == NULL, + state); + + SET_NOT_EQUAL_AND_RETURN_IF(actual->length != expected->length, state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->offset != expected->offset, state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->null_count != expected->null_count, state); + + for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + ArrowArrayViewCompareBuffer(actual, expected, i, state); + if (!state->is_equal) { + ArrowComparePrependPath(state->reason, ".buffers[%d]", i); + return; + } + } + + for (int64_t i = 0; i < actual->n_children; i++) { + ArrowArrayViewCompareIdentical(actual->children[i], expected->children[i], state); + if (!state->is_equal) { + ArrowComparePrependPath(state->reason, ".children[%" PRId64 "]", i); + return; + } + } + + if (actual->dictionary != NULL) { + ArrowArrayViewCompareIdentical(actual->dictionary, expected->dictionary, state); + if (!state->is_equal) { + ArrowComparePrependPath(state->reason, ".dictionary"); + return; + 
} + } +} + +// Top-level entry point to take care of creating, cleaning up, and +// propagating the ArrowComparisonInternalState to the caller +ArrowErrorCode ArrowArrayViewCompare(const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, + enum ArrowCompareLevel level, int* out, + struct ArrowError* reason) { + struct ArrowComparisonInternalState state; + state.level = level; + state.is_equal = 1; + state.reason = reason; + + switch (level) { + case NANOARROW_COMPARE_IDENTICAL: + ArrowArrayViewCompareIdentical(actual, expected, &state); + break; + default: + return EINVAL; + } + + *out = state.is_equal; + if (!state.is_equal) { + ArrowComparePrependPath(state.reason, "root"); + } + + return NANOARROW_OK; +} + +#undef SET_NOT_EQUAL_AND_RETURN_IF +#undef SET_NOT_EQUAL_AND_RETURN_IF_IMPL // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information diff --git a/c/vendor/nanoarrow/nanoarrow.h b/c/vendor/nanoarrow/nanoarrow.h index 82aa4b0d10..264aad5b6e 100644 --- a/c/vendor/nanoarrow/nanoarrow.h +++ b/c/vendor/nanoarrow/nanoarrow.h @@ -19,9 +19,9 @@ #define NANOARROW_BUILD_ID_H_INCLUDED #define NANOARROW_VERSION_MAJOR 0 -#define NANOARROW_VERSION_MINOR 4 +#define NANOARROW_VERSION_MINOR 6 #define NANOARROW_VERSION_PATCH 0 -#define NANOARROW_VERSION "0.4.0" +#define NANOARROW_VERSION "0.6.0" #define NANOARROW_VERSION_INT \ (NANOARROW_VERSION_MAJOR * 10000 + NANOARROW_VERSION_MINOR * 100 + \ @@ -181,14 +181,14 @@ struct ArrowArrayStream { NANOARROW_RETURN_NOT_OK((x_ <= max_) ? NANOARROW_OK : EINVAL) #if defined(NANOARROW_DEBUG) -#define _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL(NAME, EXPR, ERROR_PTR_EXPR, EXPR_STR) \ - do { \ - const int NAME = (EXPR); \ - if (NAME) { \ - ArrowErrorSet((ERROR_PTR_EXPR), "%s failed with errno %d\n* %s:%d", EXPR_STR, \ - NAME, __FILE__, __LINE__); \ - return NAME; \ - } \ +#define _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL(NAME, EXPR, ERROR_PTR_EXPR, EXPR_STR) \ + do { \ + const int NAME = (EXPR); \ + if (NAME) { \ + ArrowErrorSet((ERROR_PTR_EXPR), "%s failed with errno %d(%s)\n* %s:%d", EXPR_STR, \ + NAME, strerror(NAME), __FILE__, __LINE__); \ + return NAME; \ + } \ } while (0) #else #define _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL(NAME, EXPR, ERROR_PTR_EXPR, EXPR_STR) \ @@ -345,7 +345,7 @@ static inline void ArrowErrorSetString(struct ArrowError* error, const char* src #define NANOARROW_DCHECK(EXPR) _NANOARROW_DCHECK_IMPL(EXPR, #EXPR) #else -#define NANOARROW_ASSERT_OK(EXPR) EXPR +#define NANOARROW_ASSERT_OK(EXPR) (void)(EXPR) #define NANOARROW_DCHECK(EXPR) #endif @@ -482,7 +482,10 @@ enum ArrowType { NANOARROW_TYPE_LARGE_STRING, NANOARROW_TYPE_LARGE_BINARY, NANOARROW_TYPE_LARGE_LIST, - NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO + NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO, + NANOARROW_TYPE_RUN_END_ENCODED, + NANOARROW_TYPE_BINARY_VIEW, + NANOARROW_TYPE_STRING_VIEW }; /// \brief Get a string value of an enum ArrowType value @@ -569,6 +572,12 @@ static inline const char* ArrowTypeString(enum ArrowType type) { return "large_list"; case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: return "interval_month_day_nano"; + case NANOARROW_TYPE_RUN_END_ENCODED: + return "run_end_encoded"; + case NANOARROW_TYPE_BINARY_VIEW: + return "binary_view"; + case NANOARROW_TYPE_STRING_VIEW: + return "string_view"; default: return NULL; } @@ -605,6 +614,17 @@ enum ArrowValidationLevel { NANOARROW_VALIDATION_LEVEL_FULL = 3 }; +/// \brief Comparison level 
enumerator +/// \ingroup nanoarrow-utils +enum ArrowCompareLevel { + /// \brief Consider arrays equal if buffers contain identical content + /// and have identical offset, null count, and length. Note that this is + /// a much stricter check than logical equality, which would take into + /// account potentially different content of null slots, arrays with a + /// non-zero offset, and other considerations. + NANOARROW_COMPARE_IDENTICAL, +}; + /// \brief Get a string value of an enum ArrowTimeUnit value /// \ingroup nanoarrow-utils /// @@ -634,15 +654,13 @@ enum ArrowBufferType { NANOARROW_BUFFER_TYPE_TYPE_ID, NANOARROW_BUFFER_TYPE_UNION_OFFSET, NANOARROW_BUFFER_TYPE_DATA_OFFSET, - NANOARROW_BUFFER_TYPE_DATA + NANOARROW_BUFFER_TYPE_DATA, + NANOARROW_BUFFER_TYPE_VARIADIC_DATA, + NANOARROW_BUFFER_TYPE_VARIADIC_SIZE }; -/// \brief The maximum number of buffers in an ArrowArrayView or ArrowLayout +/// \brief The maximum number of fixed buffers in an ArrowArrayView or ArrowLayout /// \ingroup nanoarrow-array-view -/// -/// All currently supported types have 3 buffers or fewer; however, future types -/// may involve a variable number of buffers (e.g., string view). These buffers -/// will be represented by separate members of the ArrowArrayView or ArrowLayout. #define NANOARROW_MAX_FIXED_BUFFERS 3 /// \brief An non-owning view of a string @@ -689,6 +707,7 @@ union ArrowBufferViewData { const double* as_double; const float* as_float; const char* as_char; + const union ArrowBinaryView* as_binary_view; }; /// \brief An non-owning view of a buffer @@ -721,6 +740,9 @@ struct ArrowBufferAllocator { void* private_data; }; +typedef void (*ArrowBufferDeallocatorCallback)(struct ArrowBufferAllocator* allocator, + uint8_t* ptr, int64_t size); + /// \brief An owning mutable view of a buffer /// \ingroup nanoarrow-buffer struct ArrowBuffer { @@ -823,6 +845,15 @@ struct ArrowArrayView { /// type_id == union_type_id_map[128 + child_index]. This value may be /// NULL in the case where child_id == type_id. int8_t* union_type_id_map; + + /// \brief Number of variadic buffers + int32_t n_variadic_buffers; + + /// \brief Pointers to variadic buffers of binary/string_view arrays + const void** variadic_buffers; + + /// \brief Size of each variadic buffer + int64_t* variadic_buffer_sizes; }; // Used as the private data member for ArrowArrays allocated here and accessed @@ -837,8 +868,8 @@ struct ArrowArrayPrivateData { // The array of pointers to buffers. This must be updated after a sequence // of appends to synchronize its values with the actual buffer addresses - // (which may have ben reallocated uring that time) - const void* buffer_data[NANOARROW_MAX_FIXED_BUFFERS]; + // (which may have been reallocated during that time) + const void** buffer_data; // The storage data type, or NANOARROW_TYPE_UNINITIALIZED if unknown enum ArrowType storage_type; @@ -850,6 +881,15 @@ struct ArrowArrayPrivateData { // In the future this could be replaced with a type id<->child mapping // to support constructing unions in append mode where type_id != child_index int8_t union_type_id_is_child_index; + + // Number of variadic buffers for binary view types + int32_t n_variadic_buffers; + + // Variadic buffers for binary view types + struct ArrowBuffer* variadic_buffers; + + // Size of each variadic buffer in bytes + int64_t* variadic_buffer_sizes; }; /// \brief A representation of an interval. 
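As a usage sketch for the comparison API introduced in this change (an illustration, not part of the patch; `actual_view` and `expected_view` are assumed to be ArrowArrayViews already populated via ArrowArrayViewSetArray()):

    // Compare two views at the strictest level; on inequality, `reason` holds a
    // human-readable path to the first difference (e.g. "root.children[0]...").
    int is_equal = 0;
    struct ArrowError reason;
    reason.message[0] = '\0';
    int code = ArrowArrayViewCompare(&actual_view, &expected_view,
                                     NANOARROW_COMPARE_IDENTICAL, &is_equal, &reason);
    if (code != NANOARROW_OK) {
      // EINVAL indicates an unrecognized ArrowCompareLevel
    } else if (!is_equal) {
      fprintf(stderr, "arrays differ: %s\n", reason.message);  // requires <stdio.h>
    }

NANOARROW_COMPARE_IDENTICAL deliberately checks offsets, null counts, and raw buffer bytes rather than logical equality, so two logically equal arrays with different slicing will compare unequal.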
@@ -908,7 +948,7 @@ static inline void ArrowDecimalInit(struct ArrowDecimal* decimal, int32_t bitwid memset(decimal->words, 0, sizeof(decimal->words)); decimal->precision = precision; decimal->scale = scale; - decimal->n_words = bitwidth / 8 / sizeof(uint64_t); + decimal->n_words = (int)(bitwidth / 8 / sizeof(uint64_t)); if (_ArrowIsLittleEndian()) { decimal->low_word_index = 0; @@ -1049,6 +1089,8 @@ static inline void ArrowDecimalSetBytes(struct ArrowDecimal* decimal, NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowSchemaSetTypeFixedSize) #define ArrowSchemaSetTypeDecimal \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowSchemaSetTypeDecimal) +#define ArrowSchemaSetTypeRunEndEncoded \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowSchemaSetTypeRunEndEncoded) #define ArrowSchemaSetTypeDateTime \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowSchemaSetTypeDateTime) #define ArrowSchemaSetTypeUnion \ @@ -1115,6 +1157,7 @@ static inline void ArrowDecimalSetBytes(struct ArrowDecimal* decimal, NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewSetArrayMinimal) #define ArrowArrayViewValidate \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewValidate) +#define ArrowArrayViewCompare NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewCompare) #define ArrowArrayViewReset NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewReset) #define ArrowBasicArrayStreamInit \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowBasicArrayStreamInit) @@ -1168,10 +1211,8 @@ struct ArrowBufferAllocator ArrowBufferAllocatorDefault(void); /// attach a custom deallocator to an ArrowBuffer. This may be used to /// avoid copying an existing buffer that was not allocated using the /// infrastructure provided here (e.g., by an R or Python object). -struct ArrowBufferAllocator ArrowBufferDeallocator( - void (*custom_free)(struct ArrowBufferAllocator* allocator, uint8_t* ptr, - int64_t size), - void* private_data); +struct ArrowBufferAllocator ArrowBufferDeallocator(ArrowBufferDeallocatorCallback, + void* private_data); /// @} @@ -1280,6 +1321,20 @@ ArrowErrorCode ArrowDecimalSetDigits(struct ArrowDecimal* decimal, ArrowErrorCode ArrowDecimalAppendDigitsToBuffer(const struct ArrowDecimal* decimal, struct ArrowBuffer* buffer); +/// \brief Get the half float value of a float +static inline uint16_t ArrowFloatToHalfFloat(float value); + +/// \brief Get the float value of a half float +static inline float ArrowHalfFloatToFloat(uint16_t value); + +/// \brief Resolve a chunk index from increasing int64_t offsets +/// +/// Given a buffer of increasing int64_t offsets that begin with 0 (e.g., offset buffer +/// of a large type, run ends of a chunked array implementation), resolve a value v +/// where lo <= v < hi such that offsets[v] <= index < offsets[v + 1]. +static inline int64_t ArrowResolveChunk64(int64_t index, const int64_t* offsets, + int64_t lo, int64_t hi); + /// @} /// \defgroup nanoarrow-schema Creating schemas @@ -1349,6 +1404,17 @@ ArrowErrorCode ArrowSchemaSetTypeDecimal(struct ArrowSchema* schema, enum ArrowT int32_t decimal_precision, int32_t decimal_scale); +/// \brief Set the format field of a run-end encoded schema +/// +/// Returns EINVAL for run_end_type that is not +/// NANOARROW_TYPE_INT16, NANOARROW_TYPE_INT32 or NANOARROW_TYPE_INT64. +/// Schema must have been initialized using ArrowSchemaInit() or ArrowSchemaDeepCopy(). +/// The caller must call `ArrowSchemaSetTypeXXX(schema->children[1])` to +/// set the value type. 
Note that when building arrays using the `ArrowArrayAppendXXX()` +/// functions, the run-end encoded array's logical length must be updated manually. +ArrowErrorCode ArrowSchemaSetTypeRunEndEncoded(struct ArrowSchema* schema, + enum ArrowType run_end_type); + /// \brief Set the format field of a time, timestamp, or duration schema /// /// Returns EINVAL for type that is not @@ -1360,7 +1426,7 @@ ArrowErrorCode ArrowSchemaSetTypeDateTime(struct ArrowSchema* schema, enum Arrow enum ArrowTimeUnit time_unit, const char* timezone); -/// \brief Seet the format field of a union schema +/// \brief Set the format field of a union schema /// /// Returns EINVAL for a type that is not NANOARROW_TYPE_DENSE_UNION /// or NANOARROW_TYPE_SPARSE_UNION. The specified number of children are @@ -1603,14 +1669,12 @@ static inline void ArrowBufferReset(struct ArrowBuffer* buffer); /// address and resets buffer. static inline void ArrowBufferMove(struct ArrowBuffer* src, struct ArrowBuffer* dst); -/// \brief Grow or shrink a buffer to a given capacity +/// \brief Grow or shrink a buffer to a given size /// -/// When shrinking the capacity of the buffer, the buffer is only reallocated -/// if shrink_to_fit is non-zero. Calling ArrowBufferResize() does not -/// adjust the buffer's size member except to ensure that the invariant -/// capacity >= size remains true. +/// When shrinking the size of the buffer, the buffer is only reallocated +/// if shrink_to_fit is non-zero. static inline ArrowErrorCode ArrowBufferResize(struct ArrowBuffer* buffer, - int64_t new_capacity_bytes, + int64_t new_size_bytes, char shrink_to_fit); /// \brief Ensure a buffer has at least a given additional capacity @@ -1740,15 +1804,12 @@ static inline void ArrowBitmapMove(struct ArrowBitmap* src, struct ArrowBitmap* static inline ArrowErrorCode ArrowBitmapReserve(struct ArrowBitmap* bitmap, int64_t additional_size_bits); -/// \brief Grow or shrink a bitmap to a given capacity +/// \brief Grow or shrink a bitmap to a given size /// -/// When shrinking the capacity of the bitmap, the bitmap is only reallocated -/// if shrink_to_fit is non-zero. Calling ArrowBitmapResize() does not -/// adjust the buffer's size member except when shrinking new_capacity_bits -/// to a value less than the current number of bits in the bitmap. +/// When shrinking the size of the bitmap, the bitmap is only reallocated +/// if shrink_to_fit is non-zero. static inline ArrowErrorCode ArrowBitmapResize(struct ArrowBitmap* bitmap, - int64_t new_capacity_bits, - char shrink_to_fit); + int64_t new_size_bits, char shrink_to_fit); /// \brief Reserve space for and append zero or more of the same boolean value to a bitmap static inline ArrowErrorCode ArrowBitmapAppend(struct ArrowBitmap* bitmap, @@ -2021,6 +2082,48 @@ ArrowErrorCode ArrowArrayViewSetArrayMinimal(struct ArrowArrayView* array_view, const struct ArrowArray* array, struct ArrowError* error); +/// \brief Get the number of buffers +/// +/// The number of buffers referred to by this ArrowArrayView. In may cases this can also +/// be calculated from the ArrowLayout member of the ArrowArrayView or ArrowSchemaView; +/// however, for binary view and string view types, the number of total buffers depends on +/// the number of variadic buffers. +static inline int64_t ArrowArrayViewGetNumBuffers(struct ArrowArrayView* array_view); + +/// \brief Get a view of a specific buffer from an ArrowArrayView +/// +/// This is the ArrowArrayView equivalent of ArrowArray::buffers[i] that includes +/// size information (if known). 
+static inline struct ArrowBufferView ArrowArrayViewGetBufferView(
+    struct ArrowArrayView* array_view, int64_t i);
+
+/// \brief Get the function of a specific buffer in an ArrowArrayView
+///
+/// In many cases this can also be obtained from the ArrowLayout member of the
+/// ArrowArrayView or ArrowSchemaView; however, for binary view and string view types,
+/// the function of each buffer may be different between two arrays of the same type
+/// depending on the number of variadic buffers.
+static inline enum ArrowBufferType ArrowArrayViewGetBufferType(
+    struct ArrowArrayView* array_view, int64_t i);
+
+/// \brief Get the data type of a specific buffer in an ArrowArrayView
+///
+/// In many cases this can also be obtained from the ArrowLayout member of the
+/// ArrowArrayView or ArrowSchemaView; however, for binary view and string view types,
+/// the data type of each buffer may be different between two arrays of the same type
+/// depending on the number of variadic buffers.
+static inline enum ArrowType ArrowArrayViewGetBufferDataType(
+    struct ArrowArrayView* array_view, int64_t i);
+
+/// \brief Get the element size (in bits) of a specific buffer in an ArrowArrayView
+///
+/// In many cases this can also be obtained from the ArrowLayout member of the
+/// ArrowArrayView or ArrowSchemaView; however, for binary view and string view types,
+/// the element width of each buffer may be different between two arrays of the same type
+/// depending on the number of variadic buffers.
+static inline int64_t ArrowArrayViewGetBufferElementSizeBits(
+    struct ArrowArrayView* array_view, int64_t i);
+
 /// \brief Performs checks on the content of an ArrowArrayView
 ///
 /// If using ArrowArrayViewSetArray() to back array_view with an ArrowArray,
@@ -2033,6 +2136,19 @@ ArrowErrorCode ArrowArrayViewValidate(struct ArrowArrayView* array_view,
                                       enum ArrowValidationLevel validation_level,
                                       struct ArrowError* error);

+/// \brief Compare two ArrowArrayView objects for equality
+///
+/// Given two ArrowArrayView instances, place either 0 (not equal) or
+/// 1 (equal) at the address pointed to by out. If the comparison determines
+/// that actual and expected are not equal, a reason will be communicated via
+/// error if error is non-NULL.
+///
+/// Returns NANOARROW_OK if the comparison completed successfully.
+ArrowErrorCode ArrowArrayViewCompare(const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, + enum ArrowCompareLevel level, int* out, + struct ArrowError* reason); + /// \brief Reset the contents of an ArrowArrayView and frees resources void ArrowArrayViewReset(struct ArrowArrayView* array_view); @@ -2040,6 +2156,10 @@ void ArrowArrayViewReset(struct ArrowArrayView* array_view); static inline int8_t ArrowArrayViewIsNull(const struct ArrowArrayView* array_view, int64_t i); +/// \brief Compute null count for an ArrowArrayView +static inline int64_t ArrowArrayViewComputeNullCount( + const struct ArrowArrayView* array_view); + /// \brief Get the type id of a union array element static inline int8_t ArrowArrayViewUnionTypeId(const struct ArrowArrayView* array_view, int64_t i); @@ -2177,6 +2297,49 @@ ArrowErrorCode ArrowBasicArrayStreamValidate(const struct ArrowArrayStream* arra extern "C" { #endif +// Modified from Arrow C++ (1eb46f76) cpp/src/arrow/chunk_resolver.h#L133-L162 +static inline int64_t ArrowResolveChunk64(int64_t index, const int64_t* offsets, + int64_t lo, int64_t hi) { + // Similar to std::upper_bound(), but slightly different as our offsets + // array always starts with 0. + int64_t n = hi - lo; + // First iteration does not need to check for n > 1 + // (lo < hi is guaranteed by the precondition). + NANOARROW_DCHECK(n > 1); + do { + const int64_t m = n >> 1; + const int64_t mid = lo + m; + if (index >= offsets[mid]) { + lo = mid; + n -= m; + } else { + n = m; + } + } while (n > 1); + return lo; +} + +static inline int64_t ArrowResolveChunk32(int32_t index, const int32_t* offsets, + int32_t lo, int32_t hi) { + // Similar to std::upper_bound(), but slightly different as our offsets + // array always starts with 0. + int32_t n = hi - lo; + // First iteration does not need to check for n > 1 + // (lo < hi is guaranteed by the precondition). 
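  // Worked example: with offsets = {0, 3, 7}, lo = 0, hi = 2, and index = 5,
  // the loop narrows to lo = 1 since offsets[1] = 3 <= 5 < offsets[2] = 7,
  // so chunk 1 is returned.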
+ NANOARROW_DCHECK(n > 1); + do { + const int32_t m = n >> 1; + const int32_t mid = lo + m; + if (index >= offsets[mid]) { + lo = mid; + n -= m; + } else { + n = m; + } + } while (n > 1); + return lo; +} + static inline int64_t _ArrowGrowByFactor(int64_t current_capacity, int64_t new_capacity) { int64_t doubled_capacity = current_capacity * 2; if (doubled_capacity > new_capacity) { @@ -2186,6 +2349,57 @@ static inline int64_t _ArrowGrowByFactor(int64_t current_capacity, int64_t new_c } } +// float to half float conversion, adapted from Arrow Go +// https://github.com/apache/arrow/blob/main/go/arrow/float16/float16.go +static inline uint16_t ArrowFloatToHalfFloat(float value) { + union { + float f; + uint32_t b; + } u; + u.f = value; + + uint16_t sn = (uint16_t)((u.b >> 31) & 0x1); + uint16_t exp = (u.b >> 23) & 0xff; + int16_t res = (int16_t)(exp - 127 + 15); + uint16_t fc = (uint16_t)(u.b >> 13) & 0x3ff; + + if (exp == 0) { + res = 0; + } else if (exp == 0xff) { + res = 0x1f; + } else if (res > 0x1e) { + res = 0x1f; + fc = 0; + } else if (res < 0x01) { + res = 0; + fc = 0; + } + + return (uint16_t)((sn << 15) | (uint16_t)(res << 10) | fc); +} + +// half float to float conversion, adapted from Arrow Go +// https://github.com/apache/arrow/blob/main/go/arrow/float16/float16.go +static inline float ArrowHalfFloatToFloat(uint16_t value) { + uint32_t sn = (uint32_t)((value >> 15) & 0x1); + uint32_t exp = (value >> 10) & 0x1f; + uint32_t res = exp + 127 - 15; + uint32_t fc = value & 0x3ff; + + if (exp == 0) { + res = 0; + } else if (exp == 0x1f) { + res = 0xff; + } + + union { + float f; + uint32_t b; + } u; + u.b = (uint32_t)(sn << 31) | (uint32_t)(res << 23) | (uint32_t)(fc << 13); + return u.f; +} + static inline void ArrowBufferInit(struct ArrowBuffer* buffer) { buffer->data = NULL; buffer->size_bytes = 0; @@ -2195,6 +2409,8 @@ static inline void ArrowBufferInit(struct ArrowBuffer* buffer) { static inline ArrowErrorCode ArrowBufferSetAllocator( struct ArrowBuffer* buffer, struct ArrowBufferAllocator allocator) { + // This is not a perfect test for "has a buffer already been allocated" + // but is likely to catch most cases. 
if (buffer->data == NULL) { buffer->allocator = allocator; return NANOARROW_OK; @@ -2204,46 +2420,41 @@ static inline ArrowErrorCode ArrowBufferSetAllocator( } static inline void ArrowBufferReset(struct ArrowBuffer* buffer) { - if (buffer->data != NULL) { - buffer->allocator.free(&buffer->allocator, (uint8_t*)buffer->data, - buffer->capacity_bytes); - buffer->data = NULL; - } - - buffer->capacity_bytes = 0; - buffer->size_bytes = 0; + buffer->allocator.free(&buffer->allocator, (uint8_t*)buffer->data, + buffer->capacity_bytes); + ArrowBufferInit(buffer); } static inline void ArrowBufferMove(struct ArrowBuffer* src, struct ArrowBuffer* dst) { memcpy(dst, src, sizeof(struct ArrowBuffer)); src->data = NULL; - ArrowBufferReset(src); + ArrowBufferInit(src); } static inline ArrowErrorCode ArrowBufferResize(struct ArrowBuffer* buffer, - int64_t new_capacity_bytes, + int64_t new_size_bytes, char shrink_to_fit) { - if (new_capacity_bytes < 0) { + if (new_size_bytes < 0) { return EINVAL; } - if (new_capacity_bytes > buffer->capacity_bytes || shrink_to_fit) { - buffer->data = buffer->allocator.reallocate( - &buffer->allocator, buffer->data, buffer->capacity_bytes, new_capacity_bytes); - if (buffer->data == NULL && new_capacity_bytes > 0) { + int needs_reallocation = new_size_bytes > buffer->capacity_bytes || + (shrink_to_fit && new_size_bytes < buffer->capacity_bytes); + + if (needs_reallocation) { + buffer->data = buffer->allocator.reallocate(&buffer->allocator, buffer->data, + buffer->capacity_bytes, new_size_bytes); + + if (buffer->data == NULL && new_size_bytes > 0) { buffer->capacity_bytes = 0; buffer->size_bytes = 0; return ENOMEM; } - buffer->capacity_bytes = new_capacity_bytes; - } - - // Ensures that when shrinking that size <= capacity - if (new_capacity_bytes < buffer->size_bytes) { - buffer->size_bytes = new_capacity_bytes; + buffer->capacity_bytes = new_size_bytes; } + buffer->size_bytes = new_size_bytes; return NANOARROW_OK; } @@ -2254,13 +2465,25 @@ static inline ArrowErrorCode ArrowBufferReserve(struct ArrowBuffer* buffer, return NANOARROW_OK; } - return ArrowBufferResize( - buffer, _ArrowGrowByFactor(buffer->capacity_bytes, min_capacity_bytes), 0); + int64_t new_capacity_bytes = + _ArrowGrowByFactor(buffer->capacity_bytes, min_capacity_bytes); + buffer->data = buffer->allocator.reallocate(&buffer->allocator, buffer->data, + buffer->capacity_bytes, new_capacity_bytes); + + if (buffer->data == NULL && new_capacity_bytes > 0) { + buffer->capacity_bytes = 0; + buffer->size_bytes = 0; + return ENOMEM; + } + + buffer->capacity_bytes = new_capacity_bytes; + return NANOARROW_OK; } static inline void ArrowBufferAppendUnsafe(struct ArrowBuffer* buffer, const void* data, int64_t size_bytes) { if (size_bytes > 0) { + NANOARROW_DCHECK(buffer->data != NULL); memcpy(buffer->data + buffer->size_bytes, data, size_bytes); buffer->size_bytes += size_bytes; } @@ -2336,10 +2559,16 @@ static inline ArrowErrorCode ArrowBufferAppendBufferView(struct ArrowBuffer* buf static inline ArrowErrorCode ArrowBufferAppendFill(struct ArrowBuffer* buffer, uint8_t value, int64_t size_bytes) { + if (size_bytes == 0) { + return NANOARROW_OK; + } + NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, size_bytes)); + NANOARROW_DCHECK(buffer->data != NULL); // To help clang-tidy memset(buffer->data + buffer->size_bytes, value, size_bytes); buffer->size_bytes += size_bytes; + return NANOARROW_OK; } @@ -2456,6 +2685,8 @@ static inline void ArrowBitsUnpackInt32(const uint8_t* bits, int64_t start_offse return; } + 
NANOARROW_DCHECK(bits != NULL && out != NULL); + const int64_t i_begin = start_offset; const int64_t i_end = start_offset + length; const int64_t i_last_valid = i_end - 1; @@ -2498,12 +2729,18 @@ static inline void ArrowBitClear(uint8_t* bits, int64_t i) { } static inline void ArrowBitSetTo(uint8_t* bits, int64_t i, uint8_t bit_is_set) { - bits[i / 8] ^= - ((uint8_t)(-((uint8_t)(bit_is_set != 0)) ^ bits[i / 8])) & _ArrowkBitmask[i % 8]; + bits[i / 8] ^= (uint8_t)(((uint8_t)(-((uint8_t)(bit_is_set != 0)) ^ bits[i / 8])) & + _ArrowkBitmask[i % 8]); } static inline void ArrowBitsSetTo(uint8_t* bits, int64_t start_offset, int64_t length, uint8_t bits_are_set) { + if (length == 0) { + return; + } + + NANOARROW_DCHECK(bits != NULL); + const int64_t i_begin = start_offset; const int64_t i_end = start_offset + length; const uint8_t fill_byte = (uint8_t)(-bits_are_set); @@ -2547,6 +2784,8 @@ static inline int64_t ArrowBitCountSet(const uint8_t* bits, int64_t start_offset return 0; } + NANOARROW_DCHECK(bits != NULL); + const int64_t i_begin = start_offset; const int64_t i_end = start_offset + length; const int64_t i_last_valid = i_end - 1; @@ -2598,32 +2837,38 @@ static inline void ArrowBitmapMove(struct ArrowBitmap* src, struct ArrowBitmap* static inline ArrowErrorCode ArrowBitmapReserve(struct ArrowBitmap* bitmap, int64_t additional_size_bits) { int64_t min_capacity_bits = bitmap->size_bits + additional_size_bits; - if (min_capacity_bits <= (bitmap->buffer.capacity_bytes * 8)) { + int64_t min_capacity_bytes = _ArrowBytesForBits(min_capacity_bits); + int64_t current_size_bytes = bitmap->buffer.size_bytes; + int64_t current_capacity_bytes = bitmap->buffer.capacity_bytes; + + if (min_capacity_bytes <= current_capacity_bytes) { return NANOARROW_OK; } - NANOARROW_RETURN_NOT_OK( - ArrowBufferReserve(&bitmap->buffer, _ArrowBytesForBits(additional_size_bits))); + int64_t additional_capacity_bytes = min_capacity_bytes - current_size_bytes; + NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(&bitmap->buffer, additional_capacity_bytes)); + // Zero out the last byte for deterministic output in the common case + // of reserving a known remaining size. We should have returned above + // if there was not at least one additional byte to allocate; however, + // DCHECK() just to be sure. 
+ NANOARROW_DCHECK(bitmap->buffer.capacity_bytes > current_capacity_bytes); bitmap->buffer.data[bitmap->buffer.capacity_bytes - 1] = 0; return NANOARROW_OK; } static inline ArrowErrorCode ArrowBitmapResize(struct ArrowBitmap* bitmap, - int64_t new_capacity_bits, + int64_t new_size_bits, char shrink_to_fit) { - if (new_capacity_bits < 0) { + if (new_size_bits < 0) { return EINVAL; } - int64_t new_capacity_bytes = _ArrowBytesForBits(new_capacity_bits); + int64_t new_size_bytes = _ArrowBytesForBits(new_size_bits); NANOARROW_RETURN_NOT_OK( - ArrowBufferResize(&bitmap->buffer, new_capacity_bytes, shrink_to_fit)); - - if (new_capacity_bits < bitmap->size_bits) { - bitmap->size_bits = new_capacity_bits; - } + ArrowBufferResize(&bitmap->buffer, new_size_bytes, shrink_to_fit)); + bitmap->size_bits = new_size_bits; return NANOARROW_OK; } @@ -3034,6 +3279,8 @@ static inline ArrowErrorCode _ArrowArrayAppendEmptyInternal(struct ArrowArray* a switch (private_data->layout.buffer_type[i]) { case NANOARROW_BUFFER_TYPE_NONE: + case NANOARROW_BUFFER_TYPE_VARIADIC_DATA: + case NANOARROW_BUFFER_TYPE_VARIADIC_SIZE: case NANOARROW_BUFFER_TYPE_VALIDITY: continue; case NANOARROW_BUFFER_TYPE_DATA_OFFSET: @@ -3112,6 +3359,10 @@ static inline ArrowErrorCode ArrowArrayAppendInt(struct ArrowArray* array, case NANOARROW_TYPE_FLOAT: NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, (float)value)); break; + case NANOARROW_TYPE_HALF_FLOAT: + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppendUInt16(data_buffer, ArrowFloatToHalfFloat((float)value))); + break; case NANOARROW_TYPE_BOOL: NANOARROW_RETURN_NOT_OK(_ArrowArrayAppendBits(array, 1, value != 0, 1)); break; @@ -3162,6 +3413,10 @@ static inline ArrowErrorCode ArrowArrayAppendUInt(struct ArrowArray* array, case NANOARROW_TYPE_FLOAT: NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, (float)value)); break; + case NANOARROW_TYPE_HALF_FLOAT: + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppendUInt16(data_buffer, ArrowFloatToHalfFloat((float)value))); + break; case NANOARROW_TYPE_BOOL: NANOARROW_RETURN_NOT_OK(_ArrowArrayAppendBits(array, 1, value != 0, 1)); break; @@ -3191,6 +3446,10 @@ static inline ArrowErrorCode ArrowArrayAppendDouble(struct ArrowArray* array, case NANOARROW_TYPE_FLOAT: NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, (float)value)); break; + case NANOARROW_TYPE_HALF_FLOAT: + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppendUInt16(data_buffer, ArrowFloatToHalfFloat((float)value))); + break; default: return EINVAL; } @@ -3203,52 +3462,151 @@ static inline ArrowErrorCode ArrowArrayAppendDouble(struct ArrowArray* array, return NANOARROW_OK; } +// Binary views only have two fixed buffers, but be aware that they must also +// always have more 1 buffer to store variadic buffer sizes (even if there are none) +#define NANOARROW_BINARY_VIEW_FIXED_BUFFERS 2 +#define NANOARROW_BINARY_VIEW_INLINE_SIZE 12 +#define NANOARROW_BINARY_VIEW_PREFIX_SIZE 4 +#define NANOARROW_BINARY_VIEW_BLOCK_SIZE (32 << 10) // 32KB + +// The Arrow C++ implementation uses anonymous structs as members +// of the ArrowBinaryView. 
For Cython support in this library, we define +// those structs outside of the ArrowBinaryView +struct ArrowBinaryViewInlined { + int32_t size; + uint8_t data[NANOARROW_BINARY_VIEW_INLINE_SIZE]; +}; + +struct ArrowBinaryViewRef { + int32_t size; + uint8_t prefix[NANOARROW_BINARY_VIEW_PREFIX_SIZE]; + int32_t buffer_index; + int32_t offset; +}; + +union ArrowBinaryView { + struct ArrowBinaryViewInlined inlined; + struct ArrowBinaryViewRef ref; + int64_t alignment_dummy; +}; + +static inline int32_t ArrowArrayVariadicBufferCount(struct ArrowArray* array) { + struct ArrowArrayPrivateData* private_data = + (struct ArrowArrayPrivateData*)array->private_data; + + return private_data->n_variadic_buffers; +} + +static inline ArrowErrorCode ArrowArrayAddVariadicBuffers(struct ArrowArray* array, + int32_t nbuffers) { + const int32_t n_current_bufs = ArrowArrayVariadicBufferCount(array); + const int32_t nvariadic_bufs_needed = n_current_bufs + nbuffers; + + struct ArrowArrayPrivateData* private_data = + (struct ArrowArrayPrivateData*)array->private_data; + + private_data->variadic_buffers = (struct ArrowBuffer*)ArrowRealloc( + private_data->variadic_buffers, sizeof(struct ArrowBuffer) * nvariadic_bufs_needed); + if (private_data->variadic_buffers == NULL) { + return ENOMEM; + } + private_data->variadic_buffer_sizes = (int64_t*)ArrowRealloc( + private_data->variadic_buffer_sizes, sizeof(int64_t) * nvariadic_bufs_needed); + if (private_data->variadic_buffer_sizes == NULL) { + return ENOMEM; + } + + for (int32_t i = n_current_bufs; i < nvariadic_bufs_needed; i++) { + ArrowBufferInit(&private_data->variadic_buffers[i]); + private_data->variadic_buffer_sizes[i] = 0; + } + private_data->n_variadic_buffers = nvariadic_bufs_needed; + array->n_buffers = NANOARROW_BINARY_VIEW_FIXED_BUFFERS + 1 + nvariadic_bufs_needed; + + return NANOARROW_OK; +} + static inline ArrowErrorCode ArrowArrayAppendBytes(struct ArrowArray* array, struct ArrowBufferView value) { struct ArrowArrayPrivateData* private_data = (struct ArrowArrayPrivateData*)array->private_data; - struct ArrowBuffer* offset_buffer = ArrowArrayBuffer(array, 1); - struct ArrowBuffer* data_buffer = ArrowArrayBuffer( - array, 1 + (private_data->storage_type != NANOARROW_TYPE_FIXED_SIZE_BINARY)); - int32_t offset; - int64_t large_offset; - int64_t fixed_size_bytes = private_data->layout.element_size_bits[1] / 8; + if (private_data->storage_type == NANOARROW_TYPE_STRING_VIEW || + private_data->storage_type == NANOARROW_TYPE_BINARY_VIEW) { + struct ArrowBuffer* data_buffer = ArrowArrayBuffer(array, 1); + union ArrowBinaryView bvt; + bvt.inlined.size = (int32_t)value.size_bytes; - switch (private_data->storage_type) { - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_BINARY: - offset = ((int32_t*)offset_buffer->data)[array->length]; - if ((((int64_t)offset) + value.size_bytes) > INT32_MAX) { - return EOVERFLOW; + if (value.size_bytes <= NANOARROW_BINARY_VIEW_INLINE_SIZE) { + memcpy(bvt.inlined.data, value.data.as_char, value.size_bytes); + memset(bvt.inlined.data + bvt.inlined.size, 0, + NANOARROW_BINARY_VIEW_INLINE_SIZE - bvt.inlined.size); + } else { + int32_t current_n_vbufs = ArrowArrayVariadicBufferCount(array); + if (current_n_vbufs == 0 || + private_data->variadic_buffers[current_n_vbufs - 1].size_bytes + + value.size_bytes > + NANOARROW_BINARY_VIEW_BLOCK_SIZE) { + const int32_t additional_bufs_needed = 1; + NANOARROW_RETURN_NOT_OK( + ArrowArrayAddVariadicBuffers(array, additional_bufs_needed)); + current_n_vbufs += additional_bufs_needed; } - offset += 
(int32_t)value.size_bytes; - NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(offset_buffer, &offset, sizeof(int32_t))); + const int32_t buf_index = current_n_vbufs - 1; + struct ArrowBuffer* variadic_buf = &private_data->variadic_buffers[buf_index]; + memcpy(bvt.ref.prefix, value.data.as_char, NANOARROW_BINARY_VIEW_PREFIX_SIZE); + bvt.ref.buffer_index = (int32_t)buf_index; + bvt.ref.offset = (int32_t)variadic_buf->size_bytes; NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); - break; + ArrowBufferAppend(variadic_buf, value.data.as_char, value.size_bytes)); + private_data->variadic_buffer_sizes[buf_index] = variadic_buf->size_bytes; + } + NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_buffer, &bvt, sizeof(bvt))); + } else { + struct ArrowBuffer* offset_buffer = ArrowArrayBuffer(array, 1); + struct ArrowBuffer* data_buffer = ArrowArrayBuffer( + array, 1 + (private_data->storage_type != NANOARROW_TYPE_FIXED_SIZE_BINARY)); + int32_t offset; + int64_t large_offset; + int64_t fixed_size_bytes = private_data->layout.element_size_bits[1] / 8; + + switch (private_data->storage_type) { + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_BINARY: + offset = ((int32_t*)offset_buffer->data)[array->length]; + if ((((int64_t)offset) + value.size_bytes) > INT32_MAX) { + return EOVERFLOW; + } - case NANOARROW_TYPE_LARGE_STRING: - case NANOARROW_TYPE_LARGE_BINARY: - large_offset = ((int64_t*)offset_buffer->data)[array->length]; - large_offset += value.size_bytes; - NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(offset_buffer, &large_offset, sizeof(int64_t))); - NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); - break; + offset += (int32_t)value.size_bytes; + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(offset_buffer, &offset, sizeof(int32_t))); + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); + break; - case NANOARROW_TYPE_FIXED_SIZE_BINARY: - if (value.size_bytes != fixed_size_bytes) { - return EINVAL; - } + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_LARGE_BINARY: + large_offset = ((int64_t*)offset_buffer->data)[array->length]; + large_offset += value.size_bytes; + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(offset_buffer, &large_offset, sizeof(int64_t))); + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); + break; - NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); - break; - default: - return EINVAL; + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + if (value.size_bytes != fixed_size_bytes) { + return EINVAL; + } + + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); + break; + default: + return EINVAL; + } } if (private_data->bitmap.buffer.data != NULL) { @@ -3271,8 +3629,10 @@ static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array, switch (private_data->storage_type) { case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: case NANOARROW_TYPE_BINARY: case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: return ArrowArrayAppendBytes(array, buffer_view); default: return EINVAL; @@ -3459,6 +3819,132 @@ static inline void ArrowArrayViewMove(struct ArrowArrayView* src, ArrowArrayViewInitFromType(src, NANOARROW_TYPE_UNINITIALIZED); } +static inline int64_t ArrowArrayViewGetNumBuffers(struct ArrowArrayView* array_view) { + switch (array_view->storage_type) { + case 
NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + return NANOARROW_BINARY_VIEW_FIXED_BUFFERS + array_view->n_variadic_buffers + 1; + default: + break; + } + + int64_t n_buffers = 0; + for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + if (array_view->layout.buffer_type[i] == NANOARROW_BUFFER_TYPE_NONE) { + break; + } + + n_buffers++; + } + + return n_buffers; +} + +static inline struct ArrowBufferView ArrowArrayViewGetBufferView( + struct ArrowArrayView* array_view, int64_t i) { + switch (array_view->storage_type) { + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + if (i < NANOARROW_BINARY_VIEW_FIXED_BUFFERS) { + return array_view->buffer_views[i]; + } else if (i >= + (array_view->n_variadic_buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS)) { + struct ArrowBufferView view; + view.data.as_int64 = array_view->variadic_buffer_sizes; + view.size_bytes = array_view->n_variadic_buffers * sizeof(double); + return view; + } else { + struct ArrowBufferView view; + view.data.data = + array_view->variadic_buffers[i - NANOARROW_BINARY_VIEW_FIXED_BUFFERS]; + view.size_bytes = + array_view->variadic_buffer_sizes[i - NANOARROW_BINARY_VIEW_FIXED_BUFFERS]; + return view; + } + default: + // We need this check to avoid -Warray-bounds from complaining + if (i >= NANOARROW_MAX_FIXED_BUFFERS) { + struct ArrowBufferView view; + view.data.data = NULL; + view.size_bytes = 0; + return view; + } else { + return array_view->buffer_views[i]; + } + } +} + +enum ArrowBufferType ArrowArrayViewGetBufferType(struct ArrowArrayView* array_view, + int64_t i) { + switch (array_view->storage_type) { + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + if (i < NANOARROW_BINARY_VIEW_FIXED_BUFFERS) { + return array_view->layout.buffer_type[i]; + } else if (i == + (array_view->n_variadic_buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS)) { + return NANOARROW_BUFFER_TYPE_VARIADIC_SIZE; + } else { + return NANOARROW_BUFFER_TYPE_VARIADIC_DATA; + } + default: + // We need this check to avoid -Warray-bounds from complaining + if (i >= NANOARROW_MAX_FIXED_BUFFERS) { + return NANOARROW_BUFFER_TYPE_NONE; + } else { + return array_view->layout.buffer_type[i]; + } + } +} + +static inline enum ArrowType ArrowArrayViewGetBufferDataType( + struct ArrowArrayView* array_view, int64_t i) { + switch (array_view->storage_type) { + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + if (i < NANOARROW_BINARY_VIEW_FIXED_BUFFERS) { + return array_view->layout.buffer_data_type[i]; + } else if (i >= + (array_view->n_variadic_buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS)) { + return NANOARROW_TYPE_INT64; + } else if (array_view->storage_type == NANOARROW_TYPE_BINARY_VIEW) { + return NANOARROW_TYPE_BINARY; + } else { + return NANOARROW_TYPE_STRING; + } + default: + // We need this check to avoid -Warray-bounds from complaining + if (i >= NANOARROW_MAX_FIXED_BUFFERS) { + return NANOARROW_TYPE_UNINITIALIZED; + } else { + return array_view->layout.buffer_data_type[i]; + } + } +} + +static inline int64_t ArrowArrayViewGetBufferElementSizeBits( + struct ArrowArrayView* array_view, int64_t i) { + switch (array_view->storage_type) { + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + if (i < NANOARROW_BINARY_VIEW_FIXED_BUFFERS) { + return array_view->layout.element_size_bits[i]; + } else if (i >= + (array_view->n_variadic_buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS)) { + return sizeof(int64_t) * 8; + } else { + return 0; + } + default: + // We need this check to 
avoid -Warray-bounds from complaining + if (i >= NANOARROW_MAX_FIXED_BUFFERS) { + return 0; + } else { + return array_view->layout.element_size_bits[i]; + } + } +} + static inline int8_t ArrowArrayViewIsNull(const struct ArrowArrayView* array_view, int64_t i) { const uint8_t* validity_buffer = array_view->buffer_views[0].data.as_uint8; @@ -3475,12 +3961,37 @@ static inline int8_t ArrowArrayViewIsNull(const struct ArrowArrayView* array_vie } } +static inline int64_t ArrowArrayViewComputeNullCount( + const struct ArrowArrayView* array_view) { + if (array_view->length == 0) { + return 0; + } + + switch (array_view->storage_type) { + case NANOARROW_TYPE_NA: + return array_view->length; + case NANOARROW_TYPE_DENSE_UNION: + case NANOARROW_TYPE_SPARSE_UNION: + // Unions are "never null" in Arrow land + return 0; + default: + break; + } + + const uint8_t* validity_buffer = array_view->buffer_views[0].data.as_uint8; + if (validity_buffer == NULL) { + return 0; + } + return array_view->length - + ArrowBitCountSet(validity_buffer, array_view->offset, array_view->length); +} + static inline int8_t ArrowArrayViewUnionTypeId(const struct ArrowArrayView* array_view, int64_t i) { switch (array_view->storage_type) { case NANOARROW_TYPE_DENSE_UNION: case NANOARROW_TYPE_SPARSE_UNION: - return array_view->buffer_views[0].data.as_int8[i]; + return array_view->buffer_views[0].data.as_int8[array_view->offset + i]; default: return -1; } @@ -3500,9 +4011,9 @@ static inline int64_t ArrowArrayViewUnionChildOffset( const struct ArrowArrayView* array_view, int64_t i) { switch (array_view->storage_type) { case NANOARROW_TYPE_DENSE_UNION: - return array_view->buffer_views[1].data.as_int32[i]; + return array_view->buffer_views[1].data.as_int32[array_view->offset + i]; case NANOARROW_TYPE_SPARSE_UNION: - return i; + return array_view->offset + i; default: return -1; } @@ -3520,6 +4031,20 @@ static inline int64_t ArrowArrayViewListChildOffset( } } +static struct ArrowBufferView ArrowArrayViewGetBytesFromViewArrayUnsafe( + const struct ArrowArrayView* array_view, int64_t i) { + const union ArrowBinaryView* bv = &array_view->buffer_views[1].data.as_binary_view[i]; + struct ArrowBufferView out = {{NULL}, bv->inlined.size}; + if (bv->inlined.size <= NANOARROW_BINARY_VIEW_INLINE_SIZE) { + out.data.as_uint8 = bv->inlined.data; + return out; + } + + out.data.data = array_view->variadic_buffers[bv->ref.buffer_index]; + out.data.as_uint8 += bv->ref.offset; + return out; +} + static inline int64_t ArrowArrayViewGetIntUnsafe(const struct ArrowArrayView* array_view, int64_t i) { const struct ArrowBufferView* data_view = &array_view->buffer_views[1]; @@ -3546,6 +4071,8 @@ static inline int64_t ArrowArrayViewGetIntUnsafe(const struct ArrowArrayView* ar return (int64_t)data_view->data.as_double[i]; case NANOARROW_TYPE_FLOAT: return (int64_t)data_view->data.as_float[i]; + case NANOARROW_TYPE_HALF_FLOAT: + return (int64_t)ArrowHalfFloatToFloat(data_view->data.as_uint16[i]); case NANOARROW_TYPE_BOOL: return ArrowBitGet(data_view->data.as_uint8, i); default: @@ -3579,6 +4106,8 @@ static inline uint64_t ArrowArrayViewGetUIntUnsafe( return (uint64_t)data_view->data.as_double[i]; case NANOARROW_TYPE_FLOAT: return (uint64_t)data_view->data.as_float[i]; + case NANOARROW_TYPE_HALF_FLOAT: + return (uint64_t)ArrowHalfFloatToFloat(data_view->data.as_uint16[i]); case NANOARROW_TYPE_BOOL: return ArrowBitGet(data_view->data.as_uint8, i); default: @@ -3611,6 +4140,8 @@ static inline double ArrowArrayViewGetDoubleUnsafe( return 
data_view->data.as_double[i]; case NANOARROW_TYPE_FLOAT: return data_view->data.as_float[i]; + case NANOARROW_TYPE_HALF_FLOAT: + return ArrowHalfFloatToFloat(data_view->data.as_uint16[i]); case NANOARROW_TYPE_BOOL: return ArrowBitGet(data_view->data.as_uint8, i); default: @@ -3642,6 +4173,14 @@ static inline struct ArrowStringView ArrowArrayViewGetStringUnsafe( view.size_bytes = array_view->layout.element_size_bits[1] / 8; view.data = array_view->buffer_views[1].data.as_char + (i * view.size_bytes); break; + case NANOARROW_TYPE_STRING_VIEW: + case NANOARROW_TYPE_BINARY_VIEW: { + struct ArrowBufferView buf_view = + ArrowArrayViewGetBytesFromViewArrayUnsafe(array_view, i); + view.data = buf_view.data.as_char; + view.size_bytes = buf_view.size_bytes; + break; + } default: view.data = NULL; view.size_bytes = 0; @@ -3676,6 +4215,10 @@ static inline struct ArrowBufferView ArrowArrayViewGetBytesUnsafe( view.data.as_uint8 = array_view->buffer_views[1].data.as_uint8 + (i * view.size_bytes); break; + case NANOARROW_TYPE_STRING_VIEW: + case NANOARROW_TYPE_BINARY_VIEW: + view = ArrowArrayViewGetBytesFromViewArrayUnsafe(array_view, i); + break; default: view.data.data = NULL; view.size_bytes = 0; diff --git a/c/vendor/nanoarrow/nanoarrow.hpp b/c/vendor/nanoarrow/nanoarrow.hpp index 8d5b841e28..16c2e55b9f 100644 --- a/c/vendor/nanoarrow/nanoarrow.hpp +++ b/c/vendor/nanoarrow/nanoarrow.hpp @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -#include +#include +#include #include #include @@ -82,6 +83,23 @@ class Exception : public std::exception { /// @} +namespace literals { + +/// \defgroup nanoarrow_hpp-string_view_helpers ArrowStringView helpers +/// +/// Factories and equality comparison for ArrowStringView. +/// +/// @{ + +/// \brief User literal operator allowing ArrowStringView construction like "str"_asv +inline ArrowStringView operator"" _asv(const char* data, std::size_t size_bytes) { + return {data, static_cast(size_bytes)}; +} + +// @} + +} // namespace literals + namespace internal { /// \defgroup nanoarrow_hpp-unique_base Base classes for Unique wrappers @@ -199,10 +217,16 @@ template class Unique { public: /// \brief Construct an invalid instance of T holding no resources - Unique() { init_pointer(&data_); } + Unique() { + std::memset(&data_, 0, sizeof(data_)); + init_pointer(&data_); + } /// \brief Move and take ownership of data - Unique(T* data) { move_pointer(data, &data_); } + Unique(T* data) { + std::memset(&data_, 0, sizeof(data_)); + move_pointer(data, &data_); + } /// \brief Move and take ownership of data wrapped by rhs Unique(Unique&& rhs) : Unique(rhs.get()) {} @@ -241,6 +265,15 @@ class Unique { T data_; }; +template +static inline void DeallocateWrappedBuffer(struct ArrowBufferAllocator* allocator, + uint8_t* ptr, int64_t size) { + NANOARROW_UNUSED(ptr); + NANOARROW_UNUSED(size); + auto obj = reinterpret_cast(allocator->private_data); + delete obj; +} + /// @} } // namespace internal @@ -273,6 +306,51 @@ using UniqueArrayView = internal::Unique; /// @} +/// \defgroup nanoarrow_hpp-buffer Buffer helpers +/// +/// Helpers to wrap buffer-like C++ objects as ArrowBuffer objects that can +/// be used to build ArrowArray objects. +/// +/// @{ + +/// \brief Initialize a buffer wrapping an arbitrary C++ object +/// +/// Initializes a buffer with a release callback that deletes the moved obj +/// when ArrowBufferReset is called. 
This version is useful for wrapping +/// an object whose .data() member is missing or unrelated to the buffer +/// value that is destined for the buffer of an ArrowArray. T must be movable. +template +static inline void BufferInitWrapped(struct ArrowBuffer* buffer, T obj, + const uint8_t* data, int64_t size_bytes) { + T* obj_moved = new T(std::move(obj)); + buffer->data = const_cast(data); + buffer->size_bytes = size_bytes; + buffer->capacity_bytes = 0; + buffer->allocator = + ArrowBufferDeallocator(&internal::DeallocateWrappedBuffer, obj_moved); +} + +/// \brief Initialize a buffer wrapping a C++ sequence +/// +/// Specifically, this uses obj.data() to set the buffer address and +/// obj.size() * sizeof(T::value_type) to set the buffer size. This works +/// for STL containers like std::vector, std::array, and std::string. +/// This function moves obj and ensures it is deleted when ArrowBufferReset +/// is called. +template +void BufferInitSequence(struct ArrowBuffer* buffer, T obj) { + // Move before calling .data() (matters sometimes). + T* obj_moved = new T(std::move(obj)); + buffer->data = + const_cast(reinterpret_cast(obj_moved->data())); + buffer->size_bytes = obj_moved->size() * sizeof(typename T::value_type); + buffer->capacity_bytes = 0; + buffer->allocator = + ArrowBufferDeallocator(&internal::DeallocateWrappedBuffer, obj_moved); +} + +/// @} + /// \defgroup nanoarrow_hpp-array-stream ArrayStream helpers /// /// These classes provide simple ArrowArrayStream implementations that @@ -496,6 +574,359 @@ class VectorArrayStream { /// @} +namespace internal { +struct Nothing {}; + +template +class Maybe { + public: + Maybe() : nothing_(Nothing()), is_something_(false) {} + Maybe(Nothing) : Maybe() {} + + Maybe(T something) // NOLINT(google-explicit-constructor) + : something_(something), is_something_(true) {} + + explicit constexpr operator bool() const { return is_something_; } + + const T& operator*() const { return something_; } + + friend inline bool operator==(Maybe l, Maybe r) { + if (l.is_something_ != r.is_something_) return false; + return l.is_something_ ? l.something_ == r.something_ : true; + } + friend inline bool operator!=(Maybe l, Maybe r) { return !(l == r); } + + T value_or(T val) const { return is_something_ ?
something_ : val; } + + private: + // When support for gcc 4.8 is dropped, we should also assert + // is_trivially_copyable::value + static_assert(std::is_trivially_destructible::value, ""); + + union { + Nothing nothing_; + T something_; + }; + bool is_something_; +}; + +template +struct RandomAccessRange { + Get get; + int64_t size; + + using value_type = decltype(std::declval()(0)); + + struct const_iterator { + int64_t i; + const RandomAccessRange* range; + bool operator==(const_iterator other) const { return i == other.i; } + bool operator!=(const_iterator other) const { return i != other.i; } + const_iterator& operator++() { return ++i, *this; } + value_type operator*() const { return range->get(i); } + }; + + const_iterator begin() const { return {0, this}; } + const_iterator end() const { return {size, this}; } +}; + +template +struct InputRange { + Next next; + using ValueOrFalsy = decltype(std::declval()()); + + static_assert(std::is_constructible::value, ""); + static_assert(std::is_default_constructible::value, ""); + using value_type = decltype(*std::declval()); + + struct iterator { + InputRange* range; + ValueOrFalsy stashed; + + bool operator==(iterator other) const { + return static_cast(stashed) == static_cast(other.stashed); + } + bool operator!=(iterator other) const { return !(*this == other); } + + iterator& operator++() { + stashed = range->next(); + return *this; + } + value_type operator*() const { return *stashed; } + }; + + iterator begin() { return {this, next()}; } + iterator end() { return {this, ValueOrFalsy()}; } +}; +} // namespace internal + +/// \defgroup nanoarrow_hpp-range_for Range-for helpers +/// +/// The Arrow C Data interface and the Arrow C Stream interface represent +/// data which can be iterated through using C++'s range-for statement. +/// +/// @{ + +/// \brief An object convertible to any empty optional +constexpr internal::Nothing NA{}; + +/// \brief A range-for compatible wrapper for ArrowArray of fixed size type +/// +/// Provides a sequence of optional copied from each non-null +/// slot of the wrapped array (null slots result in empty optionals). +template +class ViewArrayAs { + private: + struct Get { + const uint8_t* validity; + const void* values; + int64_t offset; + + internal::Maybe operator()(int64_t i) const { + i += offset; + if (validity == nullptr || ArrowBitGet(validity, i)) { + if (std::is_same::value) { + return ArrowBitGet(static_cast(values), i); + } else { + return static_cast(values)[i]; + } + } + return NA; + } + }; + + internal::RandomAccessRange range_; + + public: + ViewArrayAs(const ArrowArrayView* array_view) + : range_{ + Get{ + array_view->buffer_views[0].data.as_uint8, + array_view->buffer_views[1].data.data, + array_view->offset, + }, + array_view->length, + } {} + + ViewArrayAs(const ArrowArray* array) + : range_{ + Get{ + static_cast(array->buffers[0]), + array->buffers[1], + /*offset=*/0, + }, + array->length, + } {} + + using value_type = typename internal::RandomAccessRange::value_type; + using const_iterator = typename internal::RandomAccessRange::const_iterator; + const_iterator begin() const { return range_.begin(); } + const_iterator end() const { return range_.end(); } + value_type operator[](int64_t i) const { return range_.get(i); } +}; + +/// \brief A range-for compatible wrapper for ArrowArray of binary or utf8 +/// +/// Provides a sequence of optional referencing each non-null +/// slot of the wrapped array (null slots result in empty optionals). 
Large +/// binary and utf8 arrays can be wrapped by specifying 64 instead of 32 for +/// the template argument. +template +class ViewArrayAsBytes { + private: + static_assert(OffsetSize == 32 || OffsetSize == 64, ""); + using OffsetType = typename std::conditional::type; + + struct Get { + const uint8_t* validity; + const void* offsets; + const char* data; + int64_t offset; + + internal::Maybe operator()(int64_t i) const { + i += offset; + auto* offsets = static_cast(this->offsets); + if (validity == nullptr || ArrowBitGet(validity, i)) { + return ArrowStringView{data + offsets[i], offsets[i + 1] - offsets[i]}; + } + return NA; + } + }; + + internal::RandomAccessRange range_; + + public: + ViewArrayAsBytes(const ArrowArrayView* array_view) + : range_{ + Get{ + array_view->buffer_views[0].data.as_uint8, + array_view->buffer_views[1].data.data, + array_view->buffer_views[2].data.as_char, + array_view->offset, + }, + array_view->length, + } {} + + ViewArrayAsBytes(const ArrowArray* array) + : range_{ + Get{ + static_cast(array->buffers[0]), + array->buffers[1], + static_cast(array->buffers[2]), + /*offset=*/0, + }, + array->length, + } {} + + using value_type = typename internal::RandomAccessRange::value_type; + using const_iterator = typename internal::RandomAccessRange::const_iterator; + const_iterator begin() const { return range_.begin(); } + const_iterator end() const { return range_.end(); } + value_type operator[](int64_t i) const { return range_.get(i); } +}; + +/// \brief A range-for compatible wrapper for ArrowArray of fixed size binary +/// +/// Provides a sequence of optional referencing each non-null +/// slot of the wrapped array (null slots result in empty optionals). +class ViewArrayAsFixedSizeBytes { + private: + struct Get { + const uint8_t* validity; + const char* data; + int64_t offset; + int fixed_size; + + internal::Maybe operator()(int64_t i) const { + i += offset; + if (validity == nullptr || ArrowBitGet(validity, i)) { + return ArrowStringView{data + i * fixed_size, fixed_size}; + } + return NA; + } + }; + + internal::RandomAccessRange range_; + + public: + ViewArrayAsFixedSizeBytes(const ArrowArrayView* array_view, int fixed_size) + : range_{ + Get{ + array_view->buffer_views[0].data.as_uint8, + array_view->buffer_views[1].data.as_char, + array_view->offset, + fixed_size, + }, + array_view->length, + } {} + + ViewArrayAsFixedSizeBytes(const ArrowArray* array, int fixed_size) + : range_{ + Get{ + static_cast(array->buffers[0]), + static_cast(array->buffers[1]), + /*offset=*/0, + fixed_size, + }, + array->length, + } {} + + using value_type = typename internal::RandomAccessRange::value_type; + using const_iterator = typename internal::RandomAccessRange::const_iterator; + const_iterator begin() const { return range_.begin(); } + const_iterator end() const { return range_.end(); } + value_type operator[](int64_t i) const { return range_.get(i); } +}; + +/// \brief A range-for compatible wrapper for ArrowArrayStream +/// +/// Provides a sequence of ArrowArray& referencing the most recent array drawn +/// from the wrapped stream. (Each array may be moved from if necessary.) +/// When streams terminate due to an error, the error code and message are +/// available for inspection through the code() and error() member functions +/// respectively. Failure to inspect the error code will result in +/// an assertion failure. The number of arrays drawn from the stream is also +/// available through the count() member function. 
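+///
+/// A minimal usage sketch (illustrative only; it assumes `stream` points to a
+/// valid, initialized ArrowArrayStream and that a failed stream should simply
+/// be reported to the caller):
+///
+/// \code{.cpp}
+/// ArrowError error;
+/// nanoarrow::ViewArrayStream rows(stream, &error);
+/// int64_t total_length = 0;
+/// for (ArrowArray& array : rows) {
+///   // Each ArrowArray& is only valid until the next iteration draws a new array
+///   total_length += array.length;
+/// }
+/// // The terminal status must be inspected, otherwise the destructor asserts
+/// if (rows.code() != NANOARROW_OK) {
+///   // error.message describes why the stream stopped early
+/// }
+/// \endcode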
+class ViewArrayStream { + public: + ViewArrayStream(ArrowArrayStream* stream, ArrowErrorCode* code, ArrowError* error) + : code_{code}, error_{error} { + // Using a slightly more verbose constructor to silence a warning that occurs + // on some versions of MSVC. + range_.next.self = this; + range_.next.stream = stream; + } + + ViewArrayStream(ArrowArrayStream* stream, ArrowError* error) + : ViewArrayStream{stream, &internal_code_, error} {} + + ViewArrayStream(ArrowArrayStream* stream) + : ViewArrayStream{stream, &internal_code_, &internal_error_} {} + + // disable copy/move of this view, since its error references may point into itself + ViewArrayStream(ViewArrayStream&&) = delete; + ViewArrayStream& operator=(ViewArrayStream&&) = delete; + ViewArrayStream(const ViewArrayStream&) = delete; + ViewArrayStream& operator=(const ViewArrayStream&) = delete; + + // ensure the error code of this stream was accessed at least once + ~ViewArrayStream() { NANOARROW_DCHECK(code_was_accessed_); } + + private: + struct Next { + ViewArrayStream* self; + ArrowArrayStream* stream; + UniqueArray array; + + ArrowArray* operator()() { + array.reset(); + *self->code_ = ArrowArrayStreamGetNext(stream, array.get(), self->error_); + + if (array->release != nullptr) { + NANOARROW_DCHECK(*self->code_ == NANOARROW_OK); + ++self->count_; + return array.get(); + } + + return nullptr; + } + }; + + internal::InputRange range_; + ArrowErrorCode* code_; + ArrowError* error_; + ArrowError internal_error_ = {}; + ArrowErrorCode internal_code_; + bool code_was_accessed_ = false; + int count_ = 0; + + public: + using value_type = typename internal::InputRange::value_type; + using iterator = typename internal::InputRange::iterator; + iterator begin() { return range_.begin(); } + iterator end() { return range_.end(); } + + /// The error code which caused this stream to terminate, if any. + ArrowErrorCode code() { + code_was_accessed_ = true; + return *code_; + } + /// The error message which caused this stream to terminate, if any. + ArrowError* error() { return error_; } + + /// The number of arrays streamed so far. + int count() const { return count_; } +}; + +/// @} + } // namespace nanoarrow +/// \brief Equality comparison operator between ArrowStringView +/// \ingroup nanoarrow_hpp-string_view_helpers +inline bool operator==(ArrowStringView l, ArrowStringView r) { + if (l.size_bytes != r.size_bytes) return false; + return memcmp(l.data, r.data, l.size_bytes) == 0; +} + #endif diff --git a/c/vendor/vendor_nanoarrow.sh b/c/vendor/vendor_nanoarrow.sh index 73887423b7..9024090fe1 100755 --- a/c/vendor/vendor_nanoarrow.sh +++ b/c/vendor/vendor_nanoarrow.sh @@ -21,7 +21,7 @@ main() { local -r repo_url="https://github.com/apache/arrow-nanoarrow" # Check releases page: https://github.com/apache/arrow-nanoarrow/releases/ - local -r commit_sha=3f83f4c48959f7a51053074672b7a330888385b1 + local -r commit_sha=33d2c8b973d8f8f424e02ac92ddeaace2a92f8dd echo "Fetching $commit_sha from $repo_url" SCRATCH=$(mktemp -d) @@ -34,21 +34,13 @@ main() { mkdir -p nanoarrow tar --strip-components 1 -C "$SCRATCH" -xf "$tarball" - # Build the bundle using cmake. We could also use the dist/ files - # but this allows us to add the symbol namespace and ensures that the - # resulting bundle is perfectly synchronized with the commit we've pulled. - pushd "$SCRATCH" - mkdir build && cd build - # Do not use "adbc" in the namespace name since our scripts expose all - # such symbols - cmake .. 
-DNANOARROW_BUNDLE=ON -DNANOARROW_NAMESPACE=Private - cmake --build . - cmake --install . --prefix=../dist-adbc - popd + # Build the bundle + python "$SCRATCH/ci/scripts/bundle.py" \ + --symbol-namespace=Private \ + --include-output-dir=nanoarrow \ + --source-output-dir=nanoarrow \ + --header-namespace= - cp "$SCRATCH/dist-adbc/nanoarrow.c" nanoarrow/ - cp "$SCRATCH/dist-adbc/nanoarrow.h" nanoarrow/ - cp "$SCRATCH/dist-adbc/nanoarrow.hpp" nanoarrow/ mv CMakeLists.nanoarrow.tmp nanoarrow/CMakeLists.txt } diff --git a/ci/conda/meta.yaml b/ci/conda/meta.yaml index a2ae0fceec..ac109997a2 100644 --- a/ci/conda/meta.yaml +++ b/ci/conda/meta.yaml @@ -17,8 +17,7 @@ package: name: arrow-adbc-split - # TODO: this needs to get bumped by the release process - version: 1.1.0 + version: 1.4.0 source: path: ../../ @@ -39,8 +38,8 @@ outputs: run: test: commands: - - test -f $PREFIX/include/adbc.h # [unix] - - test -f $PREFIX/include/adbc_driver_manager.h # [unix] + - test -f $PREFIX/include/arrow-adbc/adbc.h # [unix] + - test -f $PREFIX/include/arrow-adbc/adbc_driver_manager.h # [unix] - test -d $PREFIX/lib/cmake/AdbcDriverManager/ # [unix] - test -f $PREFIX/lib/pkgconfig/adbc-driver-manager.pc # [unix] - test ! -f $PREFIX/lib/libadbc_driver_manager.a # [unix] @@ -48,8 +47,8 @@ outputs: - test -f $PREFIX/lib/libadbc_driver_manager.dylib # [osx] - if not exist %LIBRARY_BIN%\adbc_driver_manager.dll exit 1 # [win] - - if not exist %LIBRARY_INC%\adbc.h exit 1 # [win] - - if not exist %LIBRARY_INC%\adbc_driver_manager.h exit 1 # [win] + - if not exist %LIBRARY_INC%\arrow-adbc\adbc.h exit 1 # [win] + - if not exist %LIBRARY_INC%\arrow-adbc\adbc_driver_manager.h exit 1 # [win] - if not exist %LIBRARY_LIB%\adbc_driver_manager.lib exit 1 # [win] - if not exist %LIBRARY_LIB%\cmake\AdbcDriverManager exit 1 # [win] - if not exist %LIBRARY_LIB%\pkgconfig\adbc-driver-manager.pc exit 1 # [win] diff --git a/ci/conda_env_cpp_lint.txt b/ci/conda_env_cpp_lint.txt index 7cc81c1e1a..bcdcd2e7cb 100644 --- a/ci/conda_env_cpp_lint.txt +++ b/ci/conda_env_cpp_lint.txt @@ -15,5 +15,5 @@ # specific language governing permissions and limitations # under the License. -clang=14.* -clang-tools=14.* +clang=18.* +clang-tools=18.* diff --git a/ci/conda_env_docs.txt b/ci/conda_env_docs.txt index 42151b8d29..26d20f8c68 100644 --- a/ci/conda_env_docs.txt +++ b/ci/conda_env_docs.txt @@ -15,18 +15,18 @@ # specific language governing permissions and limitations # under the License. -breathe doxygen -# XXX(https://github.com/apache/arrow-adbc/issues/987) -furo>=2023.09.10 +furo make # Needed to install mermaid nodejs numpydoc pytest -sphinx>=5.0 +sphinx>=8.1 sphinx-autobuild sphinx-copybutton sphinx-design sphinxext-opengraph +# Used in recipes +sqlalchemy>2 r-pkgdown diff --git a/ci/conda_env_glib.txt b/ci/conda_env_glib.txt index 5e3cd3266c..e6b76d97c3 100644 --- a/ci/conda_env_glib.txt +++ b/ci/conda_env_glib.txt @@ -16,8 +16,11 @@ # under the License. 
arrow-c-glib=15.0.2 -glib +glib=2.80.5 gobject-introspection meson postgresql ruby +# TODO(https://github.com/apache/arrow-adbc/issues/2176): pin for now because +# gobject-introspection uses a deprecated/removed API +setuptools <74 diff --git a/ci/conda_env_python.txt b/ci/conda_env_python.txt index cf2abeeecc..8be410973e 100644 --- a/ci/conda_env_python.txt +++ b/ci/conda_env_python.txt @@ -20,13 +20,12 @@ importlib-resources # nodejs is required by pyright nodejs >=13.0.0 pandas -pyarrow=15.0.2 +pyarrow-all pyright pytest setuptools # For integration testing -# 0.20.3 is broken on conda-forge -polars<=0.20.2 +polars protobuf python-duckdb diff --git a/ci/docker/cpp-clang-latest.dockerfile b/ci/docker/cpp-clang-latest.dockerfile new file mode 100644 index 0000000000..ff6e161e37 --- /dev/null +++ b/ci/docker/cpp-clang-latest.dockerfile @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +ARG VCPKG + +FROM debian:12 + +RUN export DEBIAN_FRONTEND=noninteractive && \ + apt-get update -y && \ + apt-get install -y curl gnupg && \ + echo "deb http://apt.llvm.org/bookworm/ llvm-toolchain-bookworm main" \ + > /etc/apt/sources.list.d/llvm.list && \ + echo "deb-src http://apt.llvm.org/bookworm/ llvm-toolchain-bookworm main" \ + >> /etc/apt/sources.list.d/llvm.list && \ + curl -L https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ + apt-get update -y && \ + apt-get install -y clang libc++abi-dev libc++-dev libomp-dev && \ + apt-get clean + +RUN export DEBIAN_FRONTEND=noninteractive && \ + apt-get install -y cmake git libpq-dev libsqlite3-dev pkg-config + +RUN curl -L -o go.tar.gz https://go.dev/dl/go1.22.5.linux-amd64.tar.gz && \ + tar -C /opt -xvf go.tar.gz && \ + echo 'export PATH=$PATH:/opt/go/bin' | tee -a ~/.bashrc diff --git a/ci/docker/python-debug.dockerfile b/ci/docker/python-debug.dockerfile new file mode 100644 index 0000000000..c17ca103dd --- /dev/null +++ b/ci/docker/python-debug.dockerfile @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +FROM ghcr.io/mamba-org/micromamba:bookworm + +ARG ARCH +ARG GO + +USER root +RUN apt update && apt install -y git make wget && apt clean + +# arm64v8 -> arm64 +RUN wget --no-verbose https://go.dev/dl/go${GO}.linux-${ARCH/v8/}.tar.gz && \ + tar -C /usr/local -xzf go${GO}.linux-${ARCH/v8/}.tar.gz && \ + rm go${GO}.linux-${ARCH/v8/}.tar.gz + +ENV PATH="/usr/local/go/bin:${PATH}" diff --git a/ci/docker/python-debug.sh b/ci/docker/python-debug.sh new file mode 100755 index 0000000000..cf03c988d9 --- /dev/null +++ b/ci/docker/python-debug.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +echo "Using debug Python ${PYTHON}" + +git config --global --add safe.directory /adbc + +# https://github.com/mamba-org/mamba/issues/3289 +cat /adbc/ci/conda_env_cpp.txt /adbc/ci/conda_env_python.txt |\ + grep -v -e '^$' |\ + grep -v -e '^#' |\ + sort |\ + tee /tmp/spec.txt + +micromamba install -c conda-forge -y \ + -f /tmp/spec.txt \ + "conda-forge/label/python_debug::python=${PYTHON}[build=*_cpython]" +micromamba clean --all -y + +export ADBC_USE_ASAN=ON +export ADBC_USE_UBSAN=ON + +env ADBC_BUILD_TESTS=OFF /adbc/ci/scripts/cpp_build.sh /adbc /adbc/build/pydebug +/adbc/ci/scripts/python_build.sh /adbc /adbc/build/pydebug +/adbc/ci/scripts/python_test.sh /adbc /adbc/build/pydebug diff --git a/ci/docker/python-wheel-manylinux-relocate.dockerfile b/ci/docker/python-wheel-manylinux-relocate.dockerfile new file mode 100644 index 0000000000..c25d7457fa --- /dev/null +++ b/ci/docker/python-wheel-manylinux-relocate.dockerfile @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +FROM debian:bookworm-slim + +RUN apt update \ + && apt install -y \ + docker.io \ + git \ + patchelf \ + python-is-python3 \ + python3-full \ + python3-pip \ + && apt clean diff --git a/ci/docker/python-wheel-manylinux.dockerfile b/ci/docker/python-wheel-manylinux.dockerfile index 2ebfb68266..1402f1a191 100644 --- a/ci/docker/python-wheel-manylinux.dockerfile +++ b/ci/docker/python-wheel-manylinux.dockerfile @@ -26,7 +26,11 @@ FROM ${REPO}:${ARCH}-python-${PYTHON}-wheel-manylinux-${MANYLINUX}-vcpkg-${VCPKG ARG ARCH ARG GO -RUN yum install -y docker +# docker is aliased to podman by AlmaLinux, but we want real Docker +# (podman is just too different) +RUN yum remove -y docker ; yum install -y yum-utils +RUN yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo +RUN yum install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin # arm64v8 -> arm64 RUN wget --no-verbose https://go.dev/dl/go${GO}.linux-${ARCH/v8/}.tar.gz && \ tar -C /usr/local -xzf go${GO}.linux-${ARCH/v8/}.tar.gz && \ diff --git a/ci/linux-packages/debian/control b/ci/linux-packages/debian/control index e93255dfca..04e09dd685 100644 --- a/ci/linux-packages/debian/control +++ b/ci/linux-packages/debian/control @@ -34,7 +34,7 @@ Build-Depends: Standards-Version: 4.5.0 Homepage: https://arrow.apache.org/adbc/ -Package: libadbc-driver-manager101 +Package: libadbc-driver-manager104 Section: libs Architecture: any Multi-Arch: same @@ -52,12 +52,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-manager101 (= ${binary:Version}) + libadbc-driver-manager104 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) driver manager . This package provides C++ header files. -Package: libadbc-driver-postgresql101 +Package: libadbc-driver-postgresql104 Section: libs Architecture: any Multi-Arch: same @@ -75,12 +75,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-postgresql101 (= ${binary:Version}) + libadbc-driver-postgresql104 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) PostgreSQL driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-sqlite101 +Package: libadbc-driver-sqlite104 Section: libs Architecture: any Multi-Arch: same @@ -98,12 +98,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-sqlite101 (= ${binary:Version}) + libadbc-driver-sqlite104 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) SQLite driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-flightsql101 +Package: libadbc-driver-flightsql104 Section: libs Architecture: any Multi-Arch: same @@ -121,12 +121,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-flightsql101 (= ${binary:Version}) + libadbc-driver-flightsql104 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) Flight SQL driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-snowflake101 +Package: libadbc-driver-snowflake104 Section: libs Architecture: any Multi-Arch: same @@ -144,7 +144,7 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-snowflake101 (= ${binary:Version}) + libadbc-driver-snowflake104 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) Snowflake driver . 
This package provides CMake package, pkg-config package and so on. @@ -158,7 +158,7 @@ Pre-Depends: ${misc:Pre-Depends} Depends: ${misc:Depends}, ${shlibs:Depends}, - libadbc-driver-manager101 (= ${binary:Version}) + libadbc-driver-manager104 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) driver manager . This package provides GLib based library files. diff --git a/ci/linux-packages/debian/libadbc-driver-flightsql101.install b/ci/linux-packages/debian/libadbc-driver-flightsql104.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-flightsql101.install rename to ci/linux-packages/debian/libadbc-driver-flightsql104.install diff --git a/ci/linux-packages/debian/libadbc-driver-manager-dev.install b/ci/linux-packages/debian/libadbc-driver-manager-dev.install index 6cee53d040..ae0e3c4ff0 100644 --- a/ci/linux-packages/debian/libadbc-driver-manager-dev.install +++ b/ci/linux-packages/debian/libadbc-driver-manager-dev.install @@ -1,5 +1,6 @@ usr/include/adbc.h usr/include/adbc_driver_manager.h +usr/include/arrow-adbc/ usr/lib/*/cmake/AdbcDriverManager/ usr/lib/*/libadbc_driver_manager.a usr/lib/*/libadbc_driver_manager.so diff --git a/ci/linux-packages/debian/libadbc-driver-manager101.install b/ci/linux-packages/debian/libadbc-driver-manager104.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-manager101.install rename to ci/linux-packages/debian/libadbc-driver-manager104.install diff --git a/ci/linux-packages/debian/libadbc-driver-postgresql101.install b/ci/linux-packages/debian/libadbc-driver-postgresql104.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-postgresql101.install rename to ci/linux-packages/debian/libadbc-driver-postgresql104.install diff --git a/ci/linux-packages/debian/libadbc-driver-snowflake101.install b/ci/linux-packages/debian/libadbc-driver-snowflake104.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-snowflake101.install rename to ci/linux-packages/debian/libadbc-driver-snowflake104.install diff --git a/ci/linux-packages/debian/libadbc-driver-sqlite101.install b/ci/linux-packages/debian/libadbc-driver-sqlite104.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-sqlite101.install rename to ci/linux-packages/debian/libadbc-driver-sqlite104.install diff --git a/ci/linux-packages/yum/apache-arrow-adbc.spec.in b/ci/linux-packages/yum/apache-arrow-adbc.spec.in index 9af14a2bf3..be4b61c5ff 100644 --- a/ci/linux-packages/yum/apache-arrow-adbc.spec.in +++ b/ci/linux-packages/yum/apache-arrow-adbc.spec.in @@ -122,6 +122,7 @@ Libraries and header files for ADBC driver manager. 
%license LICENSE.txt NOTICE.txt %{_includedir}/adbc.h %{_includedir}/adbc_driver_manager.h +%{_includedir}/arrow-adbc %{_libdir}/cmake/AdbcDriverManager/ %{_libdir}/libadbc_driver_manager.a %{_libdir}/libadbc_driver_manager.so diff --git a/ci/scripts/cpp_build.ps1 b/ci/scripts/cpp_build.ps1 index f92f972fc1..c94cbf596e 100755 --- a/ci/scripts/cpp_build.ps1 +++ b/ci/scripts/cpp_build.ps1 @@ -24,10 +24,11 @@ $InstallDir = if ($Args[2] -ne $null) { $Args[2] } else { Join-Path $BuildDir "l $BuildAll = $env:BUILD_ALL -ne "0" $BuildDriverManager = ($BuildAll -and (-not ($env:BUILD_DRIVER_MANAGER -eq "0"))) -or ($env:BUILD_DRIVER_MANAGER -eq "1") +$BuildDriverBigQuery = ($BuildAll -and (-not ($env:BUILD_DRIVER_BIGQUERY -eq "0"))) -or ($env:BUILD_DRIVER_BIGQUERY -eq "1") $BuildDriverFlightSql = ($BuildAll -and (-not ($env:BUILD_DRIVER_FLIGHTSQL -eq "0"))) -or ($env:BUILD_DRIVER_FLIGHTSQL -eq "1") $BuildDriverPostgreSQL = ($BuildAll -and (-not ($env:BUILD_DRIVER_POSTGRESQL -eq "0"))) -or ($env:BUILD_DRIVER_POSTGRESQL -eq "1") -$BuildDriverSqlite = ($BuildAll -and (-not ($env:BUILD_DRIVER_SQLITE -eq "0"))) -or ($env:BUILD_DRIVER_SQLITE -eq "1") $BuildDriverSnowflake = ($BuildAll -and (-not ($env:BUILD_DRIVER_SNOWFLAKE -eq "0"))) -or ($env:BUILD_DRIVER_SNOWFLAKE -eq "1") +$BuildDriverSqlite = ($BuildAll -and (-not ($env:BUILD_DRIVER_SQLITE -eq "0"))) -or ($env:BUILD_DRIVER_SQLITE -eq "1") function Build-Subproject { New-Item -ItemType Directory -Force -Path $BuildDir | Out-Null @@ -40,10 +41,11 @@ function Build-Subproject { -DADBC_BUILD_STATIC=OFF ` -DADBC_BUILD_TESTS=ON ` -DADBC_DRIVER_MANAGER="$($BuildDriverManager)" ` + -DADBC_DRIVER_BIGQUERY="$($BuildDriverBigQuery)" ` -DADBC_DRIVER_FLIGHTSQL="$($BuildDriverFlightSql)" ` -DADBC_DRIVER_POSTGRESQL="$($BuildDriverPostgreSQL)" ` - -DADBC_DRIVER_SQLITE="$($BuildDriverSqlite)" ` -DADBC_DRIVER_SNOWFLAKE="$($BuildDriverSnowflake)" ` + -DADBC_DRIVER_SQLITE="$($BuildDriverSqlite)" ` -DCMAKE_BUILD_TYPE=Release ` -DCMAKE_INSTALL_PREFIX="$($InstallDir)" ` -DCMAKE_VERBOSE_MAKEFILE=ON diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index b5777dbc2d..9f2118f1b9 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -24,6 +24,7 @@ set -e : ${BUILD_DRIVER_SQLITE:=${BUILD_ALL}} : ${BUILD_DRIVER_FLIGHTSQL:=${BUILD_ALL}} : ${BUILD_DRIVER_SNOWFLAKE:=${BUILD_ALL}} +: ${BUILD_DRIVER_BIGQUERY:=${BUILD_ALL}} # Must be explicitly enabled : ${BUILD_INTEGRATION_DUCKDB:=0} @@ -56,11 +57,12 @@ build_subproject() { -DADBC_BUILD_SHARED="${ADBC_BUILD_SHARED}" \ -DADBC_BUILD_STATIC="${ADBC_BUILD_STATIC}" \ -DADBC_BUILD_TESTS="${ADBC_BUILD_TESTS}" \ + -DADBC_DRIVER_BIGQUERY="${BUILD_DRIVER_BIGQUERY}" \ + -DADBC_DRIVER_FLIGHTSQL="${BUILD_DRIVER_FLIGHTSQL}" \ -DADBC_DRIVER_MANAGER="${BUILD_DRIVER_MANAGER}" \ -DADBC_DRIVER_POSTGRESQL="${BUILD_DRIVER_POSTGRESQL}" \ - -DADBC_DRIVER_SQLITE="${BUILD_DRIVER_SQLITE}" \ - -DADBC_DRIVER_FLIGHTSQL="${BUILD_DRIVER_FLIGHTSQL}" \ -DADBC_DRIVER_SNOWFLAKE="${BUILD_DRIVER_SNOWFLAKE}" \ + -DADBC_DRIVER_SQLITE="${BUILD_DRIVER_SQLITE}" \ -DADBC_INTEGRATION_DUCKDB="${BUILD_INTEGRATION_DUCKDB}" \ -DADBC_USE_ASAN="${ADBC_USE_ASAN}" \ -DADBC_USE_UBSAN="${ADBC_USE_UBSAN}" \ diff --git a/ci/scripts/cpp_clang_tidy.sh b/ci/scripts/cpp_clang_tidy.sh index ec7e43b3c6..425626fe77 100755 --- a/ci/scripts/cpp_clang_tidy.sh +++ b/ci/scripts/cpp_clang_tidy.sh @@ -66,11 +66,12 @@ build_subproject() { run-clang-tidy \ -extra-arg=-Wno-unknown-warning-option \ + -extra-arg=-Wno-unused-command-line-argument \ -j $(nproc) \ -p "${build_dir}" 
\ -fix \ -quiet \ - $(jq -r ".[] | .file" "${build_dir}/compile_commands.json") + $(jq -r ".[] | .file | select(contains(\"c/vendor\") | not)" "${build_dir}/compile_commands.json") set +x popd diff --git a/ci/scripts/cpp_recipe.sh b/ci/scripts/cpp_recipe.sh index b3cf2f3ce6..7309e7ee33 100755 --- a/ci/scripts/cpp_recipe.sh +++ b/ci/scripts/cpp_recipe.sh @@ -23,10 +23,11 @@ set -e : ${ADBC_CMAKE_ARGS:=""} : ${CMAKE_BUILD_TYPE:=Debug} -main() { - local -r source_dir="${1}" - local -r install_dir="${2}" - local -r build_dir="${3}" +test_recipe() { + local -r recipe="${1}" + local -r source_dir="${2}" + local -r install_dir="${3}" + local -r build_dir="${4}" export DYLD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${install_dir}/lib" export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${install_dir}/lib" @@ -36,11 +37,12 @@ main() { pushd "${build_dir}" set -x - cmake "${source_dir}/docs/source/cpp/recipe/" \ + cmake "${source_dir}/${recipe}/" \ ${ADBC_CMAKE_ARGS} \ -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" \ -DCMAKE_INSTALL_LIBDIR=lib \ - -DCMAKE_PREFIX_PATH="${install_dir}" + -DCMAKE_PREFIX_PATH="${install_dir}" \ + -DADBC_DRIVER_EXAMPLE_BUILD_TESTS=ON set +x cmake --build . -j @@ -49,4 +51,5 @@ main() { --no-tests=error } -main "$@" +test_recipe "docs/source/cpp/recipe" "$@" +test_recipe "docs/source/cpp/recipe_driver" "$@" diff --git a/ci/scripts/cpp_test.ps1 b/ci/scripts/cpp_test.ps1 index 0eef6a82d8..e25c2e0267 100755 --- a/ci/scripts/cpp_test.ps1 +++ b/ci/scripts/cpp_test.ps1 @@ -23,10 +23,11 @@ $InstallDir = if ($Args[1] -ne $null) { $Args[1] } else { Join-Path $BuildDir "l $BuildAll = $env:BUILD_ALL -ne "0" $BuildDriverManager = ($BuildAll -and (-not ($env:BUILD_DRIVER_MANAGER -eq "0"))) -or ($env:BUILD_DRIVER_MANAGER -eq "1") +$BuildDriverBigQuery = ($BuildAll -and (-not ($env:BUILD_DRIVER_BIGQUERY -eq "0"))) -or ($env:BUILD_DRIVER_BIGQUERY -eq "1") $BuildDriverFlightSql = ($BuildAll -and (-not ($env:BUILD_DRIVER_FLIGHTSQL -eq "0"))) -or ($env:BUILD_DRIVER_FLIGHTSQL -eq "1") $BuildDriverPostgreSQL = ($BuildAll -and (-not ($env:BUILD_DRIVER_POSTGRESQL -eq "0"))) -or ($env:BUILD_DRIVER_POSTGRESQL -eq "1") -$BuildDriverSqlite = ($BuildAll -and (-not ($env:BUILD_DRIVER_SQLITE -eq "0"))) -or ($env:BUILD_DRIVER_SQLITE -eq "1") $BuildDriverSnowflake = ($BuildAll -and (-not ($env:BUILD_DRIVER_SNOWFLAKE -eq "0"))) -or ($env:BUILD_DRIVER_SNOWFLAKE -eq "1") +$BuildDriverSqlite = ($BuildAll -and (-not ($env:BUILD_DRIVER_SQLITE -eq "0"))) -or ($env:BUILD_DRIVER_SQLITE -eq "1") $env:LD_LIBRARY_PATH += ":$($InstallDir)" $env:LD_LIBRARY_PATH += ":$($InstallDir)/bin" @@ -43,18 +44,21 @@ function Test-Project { if ($BuildDriverManager) { $labels += "|driver-manager" } + if ($BuildDriverBigQuery) { + $labels += "|driver-bigquery" + } if ($BuildDriverFlightSql) { $labels += "|driver-flightsql" } if ($BuildDriverPostgreSQL) { $labels += "|driver-postgresql" } - if ($BuildDriverSqlite) { - $labels += "|driver-sqlite" - } if ($BuildDriverSnowflake) { $labels += "|driver-snowflake" } + if ($BuildDriverSqlite) { + $labels += "|driver-sqlite" + } ctest --output-on-failure --no-tests=error -L "$($labels)" if (-not $?)
{ exit 1 } diff --git a/ci/scripts/cpp_test.sh b/ci/scripts/cpp_test.sh index e5600b561d..15c195261c 100755 --- a/ci/scripts/cpp_test.sh +++ b/ci/scripts/cpp_test.sh @@ -24,6 +24,7 @@ set -e : ${BUILD_DRIVER_SQLITE:=${BUILD_ALL}} : ${BUILD_DRIVER_FLIGHTSQL:=${BUILD_ALL}} : ${BUILD_DRIVER_SNOWFLAKE:=${BUILD_ALL}} +: ${BUILD_DRIVER_BIGQUERY:=${BUILD_ALL}} : ${BUILD_INTEGRATION_DUCKDB:=${BUILD_ALL}} test_project() { @@ -32,6 +33,9 @@ test_project() { pushd "${build_dir}/" local labels="" + if [[ "${BUILD_DRIVER_BIGQUERY}" -gt 0 ]]; then + labels="${labels}|driver-bigquery" + fi if [[ "${BUILD_DRIVER_FLIGHTSQL}" -gt 0 ]]; then labels="${labels}|driver-flightsql" fi @@ -41,12 +45,12 @@ test_project() { if [[ "${BUILD_DRIVER_POSTGRESQL}" -gt 0 ]]; then labels="${labels}|driver-postgresql" fi - if [[ "${BUILD_DRIVER_SQLITE}" -gt 0 ]]; then - labels="${labels}|driver-sqlite" - fi if [[ "${BUILD_DRIVER_SNOWFLAKE}" -gt 0 ]]; then labels="${labels}|driver-snowflake" fi + if [[ "${BUILD_DRIVER_SQLITE}" -gt 0 ]]; then + labels="${labels}|driver-sqlite" + fi if [[ "${BUILD_INTEGRATION_DUCKDB}" -gt 0 ]]; then labels="${labels}|integration-duckdb" fi diff --git a/ci/scripts/csharp_pack.ps1 b/ci/scripts/csharp_pack.ps1 index 8d135cef9a..1b34c9bf52 100644 --- a/ci/scripts/csharp_pack.ps1 +++ b/ci/scripts/csharp_pack.ps1 @@ -13,10 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +[CmdletBinding(PositionalBinding=$false)] param ( [string]$destination=$null, - [string]$versionSuffix=$null + [string]$versionSuffix=$null, + [switch]$noBuild ) $csharpFolder = [IO.Path]::Combine($PSScriptRoot, "..", "..", "csharp") | Resolve-Path @@ -35,5 +36,9 @@ if ($versionSuffix) { Write-Host " * Version Suffix: $versionSuffix" $packArgs["-version-suffix"] = $versionSuffix } +if ($noBuild) { + Write-Host " * Pack without building" + $packArgs["-no-build"] = $true +} dotnet pack @packArgs diff --git a/ci/scripts/csharp_pack.sh b/ci/scripts/csharp_pack.sh index c1390efecb..3ee701901a 100755 --- a/ci/scripts/csharp_pack.sh +++ b/ci/scripts/csharp_pack.sh @@ -23,10 +23,7 @@ source_dir=${1}/csharp pushd ${source_dir} -if [ -z ${2-} ]; then - dotnet pack -c Release; -else - dotnet pack -c Release --version-suffix ${2}; -fi +shift +dotnet pack -c Release "$@"; popd diff --git a/ci/scripts/docs_build.sh b/ci/scripts/docs_build.sh index 271dec459f..e3a28d644f 100755 --- a/ci/scripts/docs_build.sh +++ b/ci/scripts/docs_build.sh @@ -31,8 +31,15 @@ main() { pushd "$source_dir/docs" # The project name/version don't really matter here. + python "$source_dir/docs/source/ext/doxygen_inventory.py" \ + "ADBC C" \ + "version" \ + --html-path "$source_dir/c/apidoc/html" \ + --xml-path "$source_dir/c/apidoc/xml" \ + "cpp/api" \ + "$source_dir/c/apidoc" python "$source_dir/docs/source/ext/javadoc_inventory.py" \ - "ADBC" \ + "ADBC Java" \ "version" \ "$source_dir/java/target/site/apidocs" \ "java/api" @@ -40,8 +47,11 @@ main() { # We need to determine the base URL without knowing it... 
# Inject a dummy URL here, and fix it up in website_build.sh export ADBC_INTERSPHINX_MAPPING_java_adbc="http://javadocs.home.arpa/;$source_dir/java/target/site/apidocs/objects.inv" + export ADBC_INTERSPHINX_MAPPING_cpp_adbc="http://doxygen.home.arpa/;$source_dir/c/apidoc/objects.inv" - make html + sphinx-build --builder html --nitpicky --fail-on-warning --keep-going source build/html + rm -rf "$source_dir/docs/build/html/cpp/api" + cp -r "$source_dir/c/apidoc/html" "$source_dir/docs/build/html/cpp/api" rm -rf "$source_dir/docs/build/html/java/api" cp -r "$source_dir/java/target/site/apidocs" "$source_dir/docs/build/html/java/api" make doctest diff --git a/ci/scripts/install_python.sh b/ci/scripts/install_python.sh index 8eda08fe78..8f19d00171 100755 --- a/ci/scripts/install_python.sh +++ b/ci/scripts/install_python.sh @@ -27,8 +27,7 @@ platforms=([windows]=Windows [linux]=Linux) declare -A versions -versions=([3.8]=3.8.10 - [3.9]=3.9.13 +versions=([3.9]=3.9.13 [3.10]=3.10.11 [3.11]=3.11.8 [3.12]=3.12.2) diff --git a/ci/scripts/java_build.sh b/ci/scripts/java_build.sh index 463fb4210e..ae004dcf16 100755 --- a/ci/scripts/java_build.sh +++ b/ci/scripts/java_build.sh @@ -36,11 +36,7 @@ main() { pushd ${source_dir}/java mvn -B clean \ install \ - assembly:single \ - source:jar \ - javadoc:jar \ -Papache-release \ - -DdescriptorId=source-release \ -T 2C \ -DskipTests \ -Dgpg.skip diff --git a/ci/scripts/python_build.ps1 b/ci/scripts/python_build.ps1 index 14deea332b..95298ca6c0 100755 --- a/ci/scripts/python_build.ps1 +++ b/ci/scripts/python_build.ps1 @@ -22,17 +22,19 @@ $SourceDir = $Args[0] $BuildDir = $Args[1] $BuildAll = $env:BUILD_ALL -ne "0" +$BuildDriverBigQuery = ($BuildAll -and (-not ($env:BUILD_DRIVER_BIGQUERY -eq "0"))) -or ($env:BUILD_DRIVER_BIGQUERY -eq "1") $BuildDriverFlightSql = ($BuildAll -and (-not ($env:BUILD_DRIVER_FLIGHTSQL -eq "0"))) -or ($env:BUILD_DRIVER_FLIGHTSQL -eq "1") $BuildDriverManager = ($BuildAll -and (-not ($env:BUILD_DRIVER_MANAGER -eq "0"))) -or ($env:BUILD_DRIVER_MANAGER -eq "1") $BuildDriverPostgreSQL = ($BuildAll -and (-not ($env:BUILD_DRIVER_POSTGRESQL -eq "0"))) -or ($env:BUILD_DRIVER_POSTGRESQL -eq "1") -$BuildDriverSqlite = ($BuildAll -and (-not ($env:BUILD_DRIVER_SQLITE -eq "0"))) -or ($env:BUILD_DRIVER_SQLITE -eq "1") $BuildDriverSnowflake = ($BuildAll -and (-not ($env:BUILD_DRIVER_SNOWFLAKE -eq "0"))) -or ($env:BUILD_DRIVER_SNOWFLAKE -eq "1") +$BuildDriverSqlite = ($BuildAll -and (-not ($env:BUILD_DRIVER_SQLITE -eq "0"))) -or ($env:BUILD_DRIVER_SQLITE -eq "1") cmake -S "$($SourceDir)\c" -B $BuildDir ` -DADBC_DRIVER_MANAGER=$BuildDriverManager ` + -DADBC_DRIVER_BIGQUERY=$BuildDriverBigQuery ` -DADBC_DRIVER_FLIGHTSQL=$BuildDriverFlightSql ` -DADBC_DRIVER_POSTGRESQL=$BuildDriverPostgreSQL ` - -DADBC_DRIVER_SQLITE=$BuildDriverSqlite ` -DADBC_DRIVER_SNOWFLAKE=$BuildDriverSnowflake ` + -DADBC_DRIVER_SQLITE=$BuildDriverSqlite ` -DADBC_BUILD_PYTHON=ON cmake --build $BuildDir --target python diff --git a/ci/scripts/python_build.sh b/ci/scripts/python_build.sh index 76f125f590..4faa2e541e 100755 --- a/ci/scripts/python_build.sh +++ b/ci/scripts/python_build.sh @@ -19,6 +19,7 @@ set -e : ${BUILD_ALL:=1} +: ${BUILD_DRIVER_BIGQUERY:=${BUILD_ALL}} : ${BUILD_DRIVER_FLIGHTSQL:=${BUILD_ALL}} : ${BUILD_DRIVER_MANAGER:=${BUILD_ALL}} : ${BUILD_DRIVER_POSTGRESQL:=${BUILD_ALL}} @@ -43,6 +44,7 @@ main() { -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" \ -DADBC_USE_ASAN="${ADBC_USE_ASAN}" \ -DADBC_USE_UBSAN="${ADBC_USE_UBSAN}" \ + -DADBC_DRIVER_BIGQUERY=${BUILD_DRIVER_BIGQUERY} \
-DADBC_DRIVER_MANAGER=${BUILD_DRIVER_MANAGER} \ -DADBC_DRIVER_FLIGHTSQL=${BUILD_DRIVER_FLIGHTSQL} \ -DADBC_DRIVER_POSTGRESQL=${BUILD_DRIVER_POSTGRESQL} \ diff --git a/ci/scripts/python_test.ps1 b/ci/scripts/python_test.ps1 index c2e11b4348..bef8074661 100755 --- a/ci/scripts/python_test.ps1 +++ b/ci/scripts/python_test.ps1 @@ -24,10 +24,11 @@ $InstallDir = if ($Args[2] -ne $null) { $Args[2] } else { Join-Path $BuildDir "l $BuildAll = $env:BUILD_ALL -ne "0" $BuildDriverManager = ($BuildAll -and (-not ($env:BUILD_DRIVER_MANAGER -eq "0"))) -or ($env:BUILD_DRIVER_MANAGER -eq "1") +$BuildDriverBigQuery = ($BuildAll -and (-not ($env:BUILD_DRIVER_BIGQUERY -eq "0"))) -or ($env:BUILD_DRIVER_BIGQUERY -eq "1") $BuildDriverFlightSql = ($BuildAll -and (-not ($env:BUILD_DRIVER_FLIGHTSQL -eq "0"))) -or ($env:BUILD_DRIVER_FLIGHTSQL -eq "1") $BuildDriverPostgreSQL = ($BuildAll -and (-not ($env:BUILD_DRIVER_POSTGRESQL -eq "0"))) -or ($env:BUILD_DRIVER_POSTGRESQL -eq "1") -$BuildDriverSqlite = ($BuildAll -and (-not ($env:BUILD_DRIVER_SQLITE -eq "0"))) -or ($env:BUILD_DRIVER_SQLITE -eq "1") $BuildDriverSnowflake = ($BuildAll -and (-not ($env:BUILD_DRIVER_SNOWFLAKE -eq "0"))) -or ($env:BUILD_DRIVER_SNOWFLAKE -eq "1") +$BuildDriverSqlite = ($BuildAll -and (-not ($env:BUILD_DRIVER_SQLITE -eq "0"))) -or ($env:BUILD_DRIVER_SQLITE -eq "1") function Build-Subproject { $Subproject = $Args[0] @@ -55,6 +56,9 @@ if ($BuildDriverManager) { $env:PATH += ";$($SqliteDir)" Build-Subproject adbc_driver_manager } +if ($BuildDriverBigQuery) { + Build-Subproject adbc_driver_bigquery +} if ($BuildDriverFlightSql) { Build-Subproject adbc_driver_flightsql } diff --git a/ci/scripts/python_test.sh b/ci/scripts/python_test.sh index 6f95b5898e..40de05e036 100755 --- a/ci/scripts/python_test.sh +++ b/ci/scripts/python_test.sh @@ -21,6 +21,7 @@ set -e : ${ADBC_USE_ASAN:=OFF} : ${ADBC_USE_UBSAN:=OFF} : ${BUILD_ALL:=1} +: ${BUILD_DRIVER_BIGQUERY:=${BUILD_ALL}} : ${BUILD_DRIVER_FLIGHTSQL:=${BUILD_ALL}} : ${BUILD_DRIVER_MANAGER:=${BUILD_ALL}} : ${BUILD_DRIVER_POSTGRESQL:=${BUILD_ALL}} @@ -72,6 +73,10 @@ main() { install_dir="${build_dir}/local" fi + if [[ "${BUILD_DRIVER_BIGQUERY}" -gt 0 ]]; then + test_subproject "${source_dir}" "${install_dir}" adbc_driver_bigquery + fi + if [[ "${BUILD_DRIVER_FLIGHTSQL}" -gt 0 ]]; then test_subproject "${source_dir}" "${install_dir}" adbc_driver_flightsql fi diff --git a/ci/scripts/python_util.sh b/ci/scripts/python_util.sh index 027caf4980..abeabd78ba 100644 --- a/ci/scripts/python_util.sh +++ b/ci/scripts/python_util.sh @@ -19,7 +19,25 @@ set -ex -COMPONENTS="adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adbc_driver_sqlite adbc_driver_snowflake" +COMPONENTS="adbc_driver_bigquery adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adbc_driver_sqlite adbc_driver_snowflake" + +function find_drivers { + local -r build_dir="${1}/${VCPKG_ARCH}" + + if [[ $(uname) == "Linux" ]]; then + export ADBC_BIGQUERY_LIBRARY=${build_dir}/lib/libadbc_driver_bigquery.so + export ADBC_FLIGHTSQL_LIBRARY=${build_dir}/lib/libadbc_driver_flightsql.so + export ADBC_POSTGRESQL_LIBRARY=${build_dir}/lib/libadbc_driver_postgresql.so + export ADBC_SQLITE_LIBRARY=${build_dir}/lib/libadbc_driver_sqlite.so + export ADBC_SNOWFLAKE_LIBRARY=${build_dir}/lib/libadbc_driver_snowflake.so + else # macOS + export ADBC_BIGQUERY_LIBRARY=${build_dir}/lib/libadbc_driver_bigquery.dylib + export ADBC_FLIGHTSQL_LIBRARY=${build_dir}/lib/libadbc_driver_flightsql.dylib + export 
ADBC_POSTGRESQL_LIBRARY=${build_dir}/lib/libadbc_driver_postgresql.dylib + export ADBC_SQLITE_LIBRARY=${build_dir}/lib/libadbc_driver_sqlite.dylib + export ADBC_SNOWFLAKE_LIBRARY=${build_dir}/lib/libadbc_driver_snowflake.dylib + fi +} function build_drivers { local -r source_dir="$1" @@ -35,18 +53,12 @@ function build_drivers { # Add our custom triplets export VCPKG_OVERLAY_TRIPLETS="${source_dir}/ci/vcpkg/triplets/" + find_drivers "${2}" + if [[ $(uname) == "Linux" ]]; then - export ADBC_FLIGHTSQL_LIBRARY=${build_dir}/lib/libadbc_driver_flightsql.so - export ADBC_POSTGRESQL_LIBRARY=${build_dir}/lib/libadbc_driver_postgresql.so - export ADBC_SQLITE_LIBRARY=${build_dir}/lib/libadbc_driver_sqlite.so - export ADBC_SNOWFLAKE_LIBRARY=${build_dir}/lib/libadbc_driver_snowflake.so export VCPKG_DEFAULT_TRIPLET="${VCPKG_ARCH}-linux-static-release" export CMAKE_ARGUMENTS="" else # macOS - export ADBC_FLIGHTSQL_LIBRARY=${build_dir}/lib/libadbc_driver_flightsql.dylib - export ADBC_POSTGRESQL_LIBRARY=${build_dir}/lib/libadbc_driver_postgresql.dylib - export ADBC_SQLITE_LIBRARY=${build_dir}/lib/libadbc_driver_sqlite.dylib - export ADBC_SNOWFLAKE_LIBRARY=${build_dir}/lib/libadbc_driver_snowflake.dylib export VCPKG_DEFAULT_TRIPLET="${VCPKG_ARCH}-osx-static-release" if [[ "${VCPKG_ARCH}" = "x64" ]]; then export CMAKE_ARGUMENTS="-DCMAKE_OSX_ARCHITECTURES=x86_64" @@ -91,6 +103,7 @@ function build_drivers { ${CMAKE_ARGUMENTS} \ -DVCPKG_OVERLAY_TRIPLETS="${VCPKG_OVERLAY_TRIPLETS}" \ -DVCPKG_TARGET_TRIPLET="${VCPKG_DEFAULT_TRIPLET}" \ + -DADBC_DRIVER_BIGQUERY=ON \ -DADBC_DRIVER_FLIGHTSQL=ON \ -DADBC_DRIVER_POSTGRESQL=ON \ -DADBC_DRIVER_SQLITE=ON \ @@ -133,7 +146,8 @@ function setup_build_vars { export CIBW_BUILD='*-manylinux_*' export CIBW_PLATFORM="linux" fi - export CIBW_SKIP="pp* ${CIBW_SKIP}" + # No PyPy, no Python 3.8 + export CIBW_SKIP="pp* cp38-* ${CIBW_SKIP}" } function test_packages { diff --git a/ci/scripts/python_wheel_unix_build.sh b/ci/scripts/python_wheel_unix_build.sh index 12b5b69dea..f0f2b86373 100755 --- a/ci/scripts/python_wheel_unix_build.sh +++ b/ci/scripts/python_wheel_unix_build.sh @@ -38,7 +38,7 @@ function check_visibility { grep ' T ' nm_arrow.log | grep -v -E '(Adbc|DriverInit|\b_init\b|\b_fini\b)' | cat - > visible_symbols.log if [[ -f visible_symbols.log && `cat visible_symbols.log | wc -l` -eq 0 ]]; then - return 0 + echo "No unexpected symbols exported by $1" else echo "== Unexpected symbols exported by $1 ==" cat visible_symbols.log @@ -46,77 +46,29 @@ function check_visibility { exit 1 fi -} -function check_wheels { - if [[ $(uname) == "Linux" ]]; then - echo "=== Tag $component wheel with manylinux${MANYLINUX_VERSION} ===" - auditwheel repair "$@" -L . 
-w repaired_wheels - else # macOS - echo "=== Tag $component wheel with macOS ===" - delocate-wheel -v -k -w repaired_wheels "$@" + # Also check the max glibc version, to avoid accidentally bumping our + # manylinux requirement + local -r glibc_max=2.17 + local -r glibc_requirement=$(grep -Eo 'GLIBC_\S+' nm_arrow.log | awk -F_ '{print $2}' | sort --version-sort -u | tail -n1) + local -r maxver=$(echo -e "${glibc_requirement}\n${glibc_max}" | sort --version-sort | tail -n1) + if [[ "${maxver}" != "2.17" ]]; then + echo "== glibc check failed for $1 ==" + echo "Expected ${glibc_max} but found ${glibc_requirement}" + exit 1 fi } echo "=== Set up platform variables ===" setup_build_vars "${arch}" -# XXX: when we manually retag the wheel, we have to use the right arch -# tag accounting for cross-compiling, hence the replacements -PLAT_NAME=$(python -c "import sysconfig; print(sysconfig.get_platform()\ - .replace('-x86_64', '-${PYTHON_ARCH}')\ - .replace('-arm64', '-${PYTHON_ARCH}')\ - .replace('-universal2', '-${PYTHON_ARCH}'))") -if [[ "${arch}" = "arm64v8" && "$(uname)" = "Darwin" ]]; then - # Manually override the tag in this case - CI will naively generate - # "macosx_10_9_arm64" but this isn't a 'real' tag because the first - # version of macOS supporting AArch64 was macOS 11 Big Sur - PLAT_NAME="macosx_11_0_arm64" -fi - echo "=== Building C/C++ driver components ===" # Sets ADBC_POSTGRESQL_LIBRARY, ADBC_SQLITE_LIBRARY build_drivers "${source_dir}" "${build_dir}" # Check that we don't expose any unwanted symbols +check_visibility $ADBC_BIGQUERY_LIBRARY check_visibility $ADBC_FLIGHTSQL_LIBRARY check_visibility $ADBC_POSTGRESQL_LIBRARY check_visibility $ADBC_SQLITE_LIBRARY check_visibility $ADBC_SNOWFLAKE_LIBRARY - -# https://github.com/pypa/pip/issues/7555 -# Get the latest pip so we have in-tree-build by default -python -m pip install --upgrade pip auditwheel cibuildwheel delocate setuptools wheel - -# Build with Cython debug info -export ADBC_BUILD_TYPE="debug" - -for component in $COMPONENTS; do - pushd ${source_dir}/python/$component - - echo "=== Clean build artifacts ===" - rm -rf ./build ./dist ./repaired_wheels ./$component/*.so ./$component/*.so.* - - echo "=== Check $component version ===" - python $component/_version.py - - echo "=== Building $component wheel ===" - # First, create an sdist, which 1) bundles the C++ sources and 2) - # embeds the git tag. cibuildwheel may copy into a Docker - # container during build, but it only copies the package - # directory, which omits the C++ sources and .git directory, - # causing the build to fail. - python setup.py sdist - if [[ "$component" = "adbc_driver_manager" ]]; then - python -m cibuildwheel --output-dir repaired_wheels/ dist/$component-*.tar.gz - else - python -m pip wheel --no-deps -w dist -vvv . - - # Retag the wheel - python "${script_dir}/python_wheel_fix_tag.py" --plat-name="${PLAT_NAME}" dist/$component-*.whl - - check_wheels dist/$component-*.whl - fi - - popd -done diff --git a/ci/scripts/python_wheel_unix_relocate.sh b/ci/scripts/python_wheel_unix_relocate.sh new file mode 100755 index 0000000000..dde802aeec --- /dev/null +++ b/ci/scripts/python_wheel_unix_relocate.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -ex + +arch=${1} +source_dir=${2} +build_dir=${3} +script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +source "${script_dir}/python_util.sh" + +function check_wheels { + if [[ $(uname) == "Linux" ]]; then + echo "=== Tag $component wheel with manylinux${MANYLINUX_VERSION} ===" + auditwheel repair "$@" -L . -w repaired_wheels --plat manylinux_2_17_${CIBW_ARCHS} + else # macOS + echo "=== Tag $component wheel with macOS ===" + delocate-wheel -v -k -w repaired_wheels "$@" + fi +} + +echo "=== Set up platform variables ===" +setup_build_vars "${arch}" +find_drivers "${build_dir}" + +# XXX: when we manually retag the wheel, we have to use the right arch +# tag accounting for cross-compiling, hence the replacements +PLAT_NAME=$(python -c "import sysconfig; print(sysconfig.get_platform()\ + .replace('-x86_64', '-${PYTHON_ARCH}')\ + .replace('-arm64', '-${PYTHON_ARCH}')\ + .replace('-universal2', '-${PYTHON_ARCH}'))") +if [[ "${arch}" = "arm64v8" && "$(uname)" = "Darwin" ]]; then + # Manually override the tag in this case - CI will naively generate + # "macosx_10_9_arm64" but this isn't a 'real' tag because the first + # version of macOS supporting AArch64 was macOS 11 Big Sur + PLAT_NAME="macosx_11_0_arm64" +fi + +echo "=== Relocating wheels ===" +# https://github.com/pypa/pip/issues/7555 +# Get the latest pip so we have in-tree-build by default +python -m pip install --upgrade pip auditwheel 'cibuildwheel>=2.21.2' delocate setuptools wheel + +# Build with Cython debug info +export ADBC_BUILD_TYPE="debug" + +for component in $COMPONENTS; do + pushd ${source_dir}/python/$component + + echo "=== Clean build artifacts ===" + rm -rf ./build ./dist ./repaired_wheels ./$component/*.so ./$component/*.so.* + + echo "=== Check $component version ===" + python $component/_version.py + + echo "=== Building $component wheel ===" + # First, create an sdist, which 1) bundles the C++ sources and 2) + # embeds the git tag. cibuildwheel may copy into a Docker + # container during build, but it only copies the package + # directory, which omits the C++ sources and .git directory, + # causing the build to fail. + python setup.py sdist + if [[ "$component" = "adbc_driver_manager" ]]; then + python -m cibuildwheel --output-dir repaired_wheels/ dist/$component-*.tar.gz + else + python -m pip wheel --no-deps -w dist -vvv . 
+ + # Retag the wheel + python "${script_dir}/python_wheel_fix_tag.py" --plat-name="${PLAT_NAME}" dist/$component-*.whl + + check_wheels dist/$component-*.whl + fi + + popd +done diff --git a/ci/scripts/python_wheel_windows_build.bat b/ci/scripts/python_wheel_windows_build.bat index 4bc8a6f240..9d632cfcd6 100644 --- a/ci/scripts/python_wheel_windows_build.bat +++ b/ci/scripts/python_wheel_windows_build.bat @@ -47,15 +47,17 @@ cmake ^ -DCMAKE_TOOLCHAIN_FILE=%VCPKG_ROOT%/scripts/buildsystems/vcpkg.cmake ^ -DCMAKE_UNITY_BUILD=%CMAKE_UNITY_BUILD% ^ -DVCPKG_TARGET_TRIPLET=%VCPKG_TARGET_TRIPLET% ^ - -DADBC_DRIVER_POSTGRESQL=ON ^ - -DADBC_DRIVER_SQLITE=ON ^ + -DADBC_DRIVER_BIGQUERY=ON ^ -DADBC_DRIVER_FLIGHTSQL=ON ^ -DADBC_DRIVER_MANAGER=ON ^ + -DADBC_DRIVER_POSTGRESQL=ON ^ -DADBC_DRIVER_SNOWFLAKE=ON ^ + -DADBC_DRIVER_SQLITE=ON ^ %source_dir%\c || exit /B 1 cmake --build . --config %CMAKE_BUILD_TYPE% --target install --verbose -j || exit /B 1 +set ADBC_BIGQUERY_LIBRARY=%build_dir%\bin\adbc_driver_bigquery.dll set ADBC_FLIGHTSQL_LIBRARY=%build_dir%\bin\adbc_driver_flightsql.dll set ADBC_POSTGRESQL_LIBRARY=%build_dir%\bin\adbc_driver_postgresql.dll set ADBC_SQLITE_LIBRARY=%build_dir%\bin\adbc_driver_sqlite.dll @@ -70,7 +72,7 @@ python -m pip install --upgrade pip delvewheel wheel || exit /B 1 FOR /F %%i IN ('python -c "import sysconfig; print(sysconfig.get_platform())"') DO set PLAT_NAME=%%i -FOR %%c IN (adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adbc_driver_sqlite adbc_driver_snowflake) DO ( +FOR %%c IN (adbc_driver_bigquery adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adbc_driver_sqlite adbc_driver_snowflake) DO ( pushd %source_dir%\python\%%c echo "=== (%PYTHON_VERSION%) Checking %%c version ===" diff --git a/ci/scripts/python_wheel_windows_test.bat b/ci/scripts/python_wheel_windows_test.bat index f31e1c3716..963067b7bc 100644 --- a/ci/scripts/python_wheel_windows_test.bat +++ b/ci/scripts/python_wheel_windows_test.bat @@ -21,7 +21,7 @@ set source_dir=%1 echo "=== (%PYTHON_VERSION%) Installing wheels ===" -FOR %%c IN (adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adbc_driver_sqlite adbc_driver_snowflake) DO ( +FOR %%c IN (adbc_driver_bigquery adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adbc_driver_sqlite adbc_driver_snowflake) DO ( FOR %%w IN (%source_dir%\python\%%c\dist\*.whl) DO ( pip install --no-deps --force-reinstall %%w || exit /B 1 ) @@ -31,7 +31,7 @@ pip install importlib-resources pytest pyarrow pandas protobuf echo "=== (%PYTHON_VERSION%) Testing wheels ===" -FOR %%c IN (adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adbc_driver_sqlite adbc_driver_snowflake) DO ( +FOR %%c IN (adbc_driver_bigquery adbc_driver_manager adbc_driver_flightsql adbc_driver_postgresql adbc_driver_sqlite adbc_driver_snowflake) DO ( echo "=== Testing %%c ===" python -c "import %%c" || exit /B 1 python -c "import %%c.dbapi" || exit /B 1 diff --git a/ci/scripts/r_build.sh b/ci/scripts/r_build.sh index 79af0e899a..fbf4d039c2 100755 --- a/ci/scripts/r_build.sh +++ b/ci/scripts/r_build.sh @@ -63,6 +63,10 @@ main() { if [[ "${BUILD_DRIVER_SNOWFLAKE}" -gt 0 ]]; then install_pkg "${source_dir}" "${install_dir}" adbcsnowflake fi + + if [[ "${BUILD_DRIVER_BIGQUERY}" -gt 0 ]]; then + install_pkg "${source_dir}" "${install_dir}" adbcbigquery + fi } main "$@" diff --git a/ci/scripts/run_cgo_drivermgr_check.sh b/ci/scripts/run_cgo_drivermgr_check.sh index a03da986d6..946ab0906d 100755 --- a/ci/scripts/run_cgo_drivermgr_check.sh 
+++ b/ci/scripts/run_cgo_drivermgr_check.sh @@ -32,8 +32,8 @@ main() { for f in "$@"; do fn=$(basename $f) - if ! diff -q "$f" "go/adbc/drivermgr/$fn" &>/dev/null; then - >&2 echo "OUT OF SYNC: $f differs from go/adbc/drivermgr/$fn" + if ! diff -q "$f" "go/adbc/drivermgr/arrow-adbc/$fn" &>/dev/null; then + >&2 echo "OUT OF SYNC: $f differs from go/adbc/drivermgr/arrow-adbc/$fn" popd return 1 fi diff --git a/r/adbcsqlite/src/c/driver/framework/.gitignore b/ci/scripts/rust_build.sh old mode 100644 new mode 100755 similarity index 84% rename from r/adbcsqlite/src/c/driver/framework/.gitignore rename to ci/scripts/rust_build.sh index 9e7f94dec3..42630136c0 --- a/r/adbcsqlite/src/c/driver/framework/.gitignore +++ b/ci/scripts/rust_build.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,14 +16,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -base_driver.cc -base_driver.h -base_connection.h -base_database.h -base_statement.h -catalog.h -objects.h -status.h -type_fwd.h -catalog.cc -objects.cc + +set -euxo pipefail + +source_dir="${1}/rust" + +pushd "${source_dir}" +cargo build --all-features --all-targets --workspace +popd diff --git a/ci/scripts/rust_test.sh b/ci/scripts/rust_test.sh new file mode 100755 index 0000000000..439ad5e33c --- /dev/null +++ b/ci/scripts/rust_test.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +source_dir="${1}/rust" +cpp_libs_dir="${2}" + +export LD_LIBRARY_PATH="${cpp_libs_dir}/lib:${LD_LIBRARY_PATH:-}" +export DYLD_LIBRARY_PATH="${cpp_libs_dir}/lib:${DYLD_LIBRARY_PATH:-}" + +pushd "${source_dir}" +cargo test --all-features --workspace +popd diff --git a/ci/scripts/website_build.sh b/ci/scripts/website_build.sh index 1f7f146e1c..91c62e04bb 100755 --- a/ci/scripts/website_build.sh +++ b/ci/scripts/website_build.sh @@ -46,14 +46,17 @@ main() { exit 1 fi - local -r regex='^([0-9]+\.[0-9]+\.[0-9]+)$' + # Docs use the ADBC release so it will just be 12, 13, 14, ... 
+ local -r regex='^[0-9]+$' local directory="main" if [[ "${new_version}" =~ $regex ]]; then + echo "Adding docs for version ${new_version}" cp -r "${docs}" "${site}/${new_version}" git -C "${site}" add --force "${new_version}" directory="${new_version}" else # Assume this is dev docs + echo "Adding dev docs for version ${new_version}" rm -rf "${site}/main" cp -r "${docs}" "${site}/main" git -C "${site}" add --force "main" @@ -63,6 +66,7 @@ main() { # Fix up lazy Intersphinx links (see docs_build.sh) # Assumes GNU sed sed -i "s|http://javadocs.home.arpa/|https://arrow.apache.org/adbc/${directory}/|g" $(grep -Rl javadocs.home.arpa "${site}/${directory}/") + sed -i "s|http://doxygen.home.arpa/|https://arrow.apache.org/adbc/${directory}/|g" $(grep -Rl doxygen.home.arpa "${site}/${directory}/") git -C "${site}" add --force "${directory}" # Copy the version script and regenerate the version list @@ -81,7 +85,7 @@ main() { popd # Determine the latest stable version - local -r latest_docs=$(grep -E ';[0-9]+\.[0-9]+\.[0-9]+$' "${site}/versions.txt" | sort -t ';' --version-sort | tail -n1) + local -r latest_docs=$(grep -E ';[0-9]+(\.[0-9]+\.[0-9]+)?$' "${site}/versions.txt" | sort -t ';' --version-sort | tail -n1) if [[ -z "${latest_docs}" ]]; then echo "No stable versions found" local -r latest_dir="main" diff --git a/csharp/Apache.Arrow.Adbc.sln b/csharp/Apache.Arrow.Adbc.sln index 139d88d43e..9db956c039 100644 --- a/csharp/Apache.Arrow.Adbc.sln +++ b/csharp/Apache.Arrow.Adbc.sln @@ -10,12 +10,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Tests", "Tests", "{5BD04C26 EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Drivers", "Drivers", "{FEB257A0-4FD3-495E-9A47-9E1649755445}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Drivers.FlightSql", "src\Drivers\FlightSql\Apache.Arrow.Adbc.Drivers.FlightSql.csproj", "{19AA450A-2F87-49BD-9122-8AD07D4C6DCE}" -EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Drivers", "Drivers", "{C7290227-E925-47E7-8B6B-A8B171645D58}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Tests.Drivers.FlightSql", "test\Drivers\FlightSql\Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj", "{5C15E79C-19C4-4FF4-BB82-28754FE3966B}" -EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Client", "Client", "{B6111602-2DC4-4B2F-9598-E3EE1972D3E4}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Client", "src\Client\Apache.Arrow.Adbc.Client.csproj", "{A405F4A0-5938-4139-B2DF-ED9A05EC3D7C}" @@ -36,6 +32,14 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Drivers.A EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Tests.Drivers.Apache", "test\Drivers\Apache\Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj", "{714F0BD2-3A92-4D1A-8FAC-D0C0599BE3E3}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Drivers.Interop.FlightSql", "src\Drivers\Interop\FlightSql\Apache.Arrow.Adbc.Drivers.Interop.FlightSql.csproj", "{4076D7E9-728D-4DF4-999F-658784957648}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql", "test\Drivers\Interop\FlightSql\Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql.csproj", "{C5503227-C5A7-406F-83AA-681F292EA61F}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Drivers.FlightSql", 
"src\Drivers\FlightSql\Apache.Arrow.Adbc.Drivers.FlightSql.csproj", "{77D5A92F-4136-4DE7-81F4-43B981223280}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Adbc.Tests.Drivers.FlightSql", "test\Drivers\FlightSql\Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj", "{5B27FB02-D4AE-4ACB-AD88-5E64EEB61729}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -50,14 +54,6 @@ Global {00C143BA-F1CF-4117-9DE6-E73DC4D208F8}.Debug|Any CPU.Build.0 = Debug|Any CPU {00C143BA-F1CF-4117-9DE6-E73DC4D208F8}.Release|Any CPU.ActiveCfg = Release|Any CPU {00C143BA-F1CF-4117-9DE6-E73DC4D208F8}.Release|Any CPU.Build.0 = Release|Any CPU - {19AA450A-2F87-49BD-9122-8AD07D4C6DCE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {19AA450A-2F87-49BD-9122-8AD07D4C6DCE}.Debug|Any CPU.Build.0 = Debug|Any CPU - {19AA450A-2F87-49BD-9122-8AD07D4C6DCE}.Release|Any CPU.ActiveCfg = Release|Any CPU - {19AA450A-2F87-49BD-9122-8AD07D4C6DCE}.Release|Any CPU.Build.0 = Release|Any CPU - {5C15E79C-19C4-4FF4-BB82-28754FE3966B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {5C15E79C-19C4-4FF4-BB82-28754FE3966B}.Debug|Any CPU.Build.0 = Debug|Any CPU - {5C15E79C-19C4-4FF4-BB82-28754FE3966B}.Release|Any CPU.ActiveCfg = Release|Any CPU - {5C15E79C-19C4-4FF4-BB82-28754FE3966B}.Release|Any CPU.Build.0 = Release|Any CPU {A405F4A0-5938-4139-B2DF-ED9A05EC3D7C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {A405F4A0-5938-4139-B2DF-ED9A05EC3D7C}.Debug|Any CPU.Build.0 = Debug|Any CPU {A405F4A0-5938-4139-B2DF-ED9A05EC3D7C}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -94,15 +90,29 @@ Global {714F0BD2-3A92-4D1A-8FAC-D0C0599BE3E3}.Debug|Any CPU.Build.0 = Debug|Any CPU {714F0BD2-3A92-4D1A-8FAC-D0C0599BE3E3}.Release|Any CPU.ActiveCfg = Release|Any CPU {714F0BD2-3A92-4D1A-8FAC-D0C0599BE3E3}.Release|Any CPU.Build.0 = Release|Any CPU + {4076D7E9-728D-4DF4-999F-658784957648}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4076D7E9-728D-4DF4-999F-658784957648}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4076D7E9-728D-4DF4-999F-658784957648}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4076D7E9-728D-4DF4-999F-658784957648}.Release|Any CPU.Build.0 = Release|Any CPU + {C5503227-C5A7-406F-83AA-681F292EA61F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C5503227-C5A7-406F-83AA-681F292EA61F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C5503227-C5A7-406F-83AA-681F292EA61F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C5503227-C5A7-406F-83AA-681F292EA61F}.Release|Any CPU.Build.0 = Release|Any CPU + {77D5A92F-4136-4DE7-81F4-43B981223280}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {77D5A92F-4136-4DE7-81F4-43B981223280}.Debug|Any CPU.Build.0 = Debug|Any CPU + {77D5A92F-4136-4DE7-81F4-43B981223280}.Release|Any CPU.ActiveCfg = Release|Any CPU + {77D5A92F-4136-4DE7-81F4-43B981223280}.Release|Any CPU.Build.0 = Release|Any CPU + {5B27FB02-D4AE-4ACB-AD88-5E64EEB61729}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5B27FB02-D4AE-4ACB-AD88-5E64EEB61729}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5B27FB02-D4AE-4ACB-AD88-5E64EEB61729}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5B27FB02-D4AE-4ACB-AD88-5E64EEB61729}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(NestedProjects) = preSolution {00C143BA-F1CF-4117-9DE6-E73DC4D208F8} = {5BD04C26-CE52-4893-8C1A-479705195CEF} - {19AA450A-2F87-49BD-9122-8AD07D4C6DCE} = {FEB257A0-4FD3-495E-9A47-9E1649755445} {C7290227-E925-47E7-8B6B-A8B171645D58} = 
{5BD04C26-CE52-4893-8C1A-479705195CEF} - {5C15E79C-19C4-4FF4-BB82-28754FE3966B} = {C7290227-E925-47E7-8B6B-A8B171645D58} {A405F4A0-5938-4139-B2DF-ED9A05EC3D7C} = {B6111602-2DC4-4B2F-9598-E3EE1972D3E4} {A748041C-EF9A-4E88-B6FB-9F2D6CB79170} = {FEB257A0-4FD3-495E-9A47-9E1649755445} {EA43BB7C-BC00-4701-BDF4-367880C2495C} = {C7290227-E925-47E7-8B6B-A8B171645D58} @@ -112,6 +122,10 @@ Global {35281025-2FE3-4BE0-BB76-600F93386C87} = {C7290227-E925-47E7-8B6B-A8B171645D58} {6C0D8BE1-4A23-4C2F-88B1-D2FBEA0B1903} = {FEB257A0-4FD3-495E-9A47-9E1649755445} {714F0BD2-3A92-4D1A-8FAC-D0C0599BE3E3} = {C7290227-E925-47E7-8B6B-A8B171645D58} + {4076D7E9-728D-4DF4-999F-658784957648} = {FEB257A0-4FD3-495E-9A47-9E1649755445} + {C5503227-C5A7-406F-83AA-681F292EA61F} = {C7290227-E925-47E7-8B6B-A8B171645D58} + {77D5A92F-4136-4DE7-81F4-43B981223280} = {FEB257A0-4FD3-495E-9A47-9E1649755445} + {5B27FB02-D4AE-4ACB-AD88-5E64EEB61729} = {C7290227-E925-47E7-8B6B-A8B171645D58} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {4795CF16-0FDB-4BE0-9768-5CF31564DC03} diff --git a/csharp/Directory.Build.props b/csharp/Directory.Build.props index 4b63d46d6b..c771bc43b8 100644 --- a/csharp/Directory.Build.props +++ b/csharp/Directory.Build.props @@ -29,7 +29,7 @@ Apache Arrow ADBC library Copyright 2022-2024 The Apache Software Foundation The Apache Software Foundation - 0.13.0 + 0.16.0 SNAPSHOT diff --git a/csharp/src/Apache.Arrow.Adbc/AdbcConnection11.cs b/csharp/src/Apache.Arrow.Adbc/AdbcConnection11.cs index e4f79cbbcf..5ea4337062 100644 --- a/csharp/src/Apache.Arrow.Adbc/AdbcConnection11.cs +++ b/csharp/src/Apache.Arrow.Adbc/AdbcConnection11.cs @@ -32,6 +32,7 @@ public abstract class AdbcConnection11 : IDisposable , IAsyncDisposable #endif { + ~AdbcConnection11() => Dispose(false); /// diff --git a/csharp/src/Apache.Arrow.Adbc/AdbcDriverLoader.cs b/csharp/src/Apache.Arrow.Adbc/AdbcDriverLoader.cs new file mode 100644 index 0000000000..ac861ca05f --- /dev/null +++ b/csharp/src/Apache.Arrow.Adbc/AdbcDriverLoader.cs @@ -0,0 +1,98 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.IO; +using System.Runtime.InteropServices; +using Apache.Arrow.Adbc.C; + +namespace Apache.Arrow.Adbc +{ + /// + /// Lightweight class for loading an Interop driver to .NET. + /// + public abstract class AdbcDriverLoader + { + readonly string driverShortName; + readonly string entryPoint; + + /// + /// Initializes the driver loader with the driver name and entry point. + /// + /// Short driver name, with no extension. + /// The entry point. Defaults to `AdbcDriverInit` if not provided. 
+ /// + protected AdbcDriverLoader(string driverShortName, string entryPoint = "AdbcDriverInit") + { + if (string.IsNullOrEmpty(driverShortName)) + throw new ArgumentException("cannot be null or empty", nameof(driverShortName)); + + if (string.IsNullOrEmpty(entryPoint)) + throw new ArgumentException("cannot be null or empty", nameof(entryPoint)); + + this.driverShortName = driverShortName; + this.entryPoint = entryPoint; + } + + /// + /// Loads the Interop from the current directory using the default name and entry point. + /// + /// An based on the Flight SQL Go driver. + /// + protected AdbcDriver FindAndLoadDriver() + { + string root = "runtimes"; + string native = "native"; + string architecture = RuntimeInformation.OSArchitecture.ToString().ToLower(); + string fileName = driverShortName; + string file; + + // matches extensions in https://github.com/apache/arrow-adbc/blob/main/go/adbc/pkg/Makefile + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + file = Path.Combine(root, $"linux-{architecture}", native, $"{fileName}.so"); + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + file = Path.Combine(root, $"win-{architecture}", native, $"{fileName}.dll"); + else + file = Path.Combine(root, $"osx-{architecture}", native, $"{fileName}.dylib"); + + if (File.Exists(file)) + { + // get the full path because some .NET versions need it + file = Path.GetFullPath(file); + } + else + { + throw new FileNotFoundException($"Could not find {file}"); + } + + return LoadDriver(file, entryPoint); + } + + /// + /// Loads the Interop driver from the current directory using the default name and entry point. + /// + /// The file to load. + /// The entry point of the file. + /// An . + public static AdbcDriver LoadDriver(string file, string entryPoint) + { + AdbcDriver driver = CAdbcDriverImporter.Load(file, entryPoint); + + return driver; + } + } +} diff --git a/csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs b/csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs index 1fbd28c1a6..98eec4438d 100644 --- a/csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs +++ b/csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs @@ -16,10 +16,8 @@ */ using System; -using System.IO; using System.Threading.Tasks; using Apache.Arrow.Ipc; -using Apache.Arrow.Types; namespace Apache.Arrow.Adbc { @@ -157,108 +155,6 @@ public virtual void Dispose() { } - /// - /// Gets a value from the Arrow array at the specified index, - /// using the Field metadata for information. - /// - /// - /// The Arrow array. - /// - /// - /// The from the that can - /// be used for metadata inspection. - /// - /// - /// The index in the array to get the value from. - /// - public virtual object? 
GetValue(IArrowArray arrowArray, int index) - { - if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray)); - if (index < 0) throw new ArgumentOutOfRangeException(nameof(index)); - - switch (arrowArray) - { - case BooleanArray booleanArray: - return booleanArray.GetValue(index); - case Date32Array date32Array: - return date32Array.GetDateTime(index); - case Date64Array date64Array: - return date64Array.GetDateTime(index); - case Decimal128Array decimal128Array: - return decimal128Array.GetSqlDecimal(index); - case Decimal256Array decimal256Array: - return decimal256Array.GetString(index); - case DoubleArray doubleArray: - return doubleArray.GetValue(index); - case FloatArray floatArray: - return floatArray.GetValue(index); -#if NET5_0_OR_GREATER - case PrimitiveArray halfFloatArray: - return halfFloatArray.GetValue(index); -#endif - case Int8Array int8Array: - return int8Array.GetValue(index); - case Int16Array int16Array: - return int16Array.GetValue(index); - case Int32Array int32Array: - return int32Array.GetValue(index); - case Int64Array int64Array: - return int64Array.GetValue(index); - case StringArray stringArray: - return stringArray.GetString(index); -#if NET6_0_OR_GREATER - case Time32Array time32Array: - return time32Array.GetTime(index); - case Time64Array time64Array: - return time64Array.GetTime(index); -#else - case Time32Array time32Array: - int? time32 = time32Array.GetValue(index); - if (time32 == null) { return null; } - return ((Time32Type)time32Array.Data.DataType).Unit switch - { - TimeUnit.Second => TimeSpan.FromSeconds(time32.Value), - TimeUnit.Millisecond => TimeSpan.FromMilliseconds(time32.Value), - _ => throw new InvalidDataException("Unsupported time unit for Time32Type") - }; - case Time64Array time64Array: - long? time64 = time64Array.GetValue(index); - if (time64 == null) { return null; } - return ((Time64Type)time64Array.Data.DataType).Unit switch - { - TimeUnit.Microsecond => TimeSpan.FromTicks(time64.Value * 10), - TimeUnit.Nanosecond => TimeSpan.FromTicks(time64.Value / 100), - _ => throw new InvalidDataException("Unsupported time unit for Time64Type") - }; -#endif - case TimestampArray timestampArray: - return timestampArray.GetTimestamp(index); - case UInt8Array uInt8Array: - return uInt8Array.GetValue(index); - case UInt16Array uInt16Array: - return uInt16Array.GetValue(index); - case UInt32Array uInt32Array: - return uInt32Array.GetValue(index); - case UInt64Array uInt64Array: - return uInt64Array.GetValue(index); - - case BinaryArray binaryArray: - if (!binaryArray.IsNull(index)) - return binaryArray.GetBytes(index).ToArray(); - - return null; - - // not covered: - // -- struct array - // -- dictionary array - // -- fixed size binary - // -- list array - // -- union array - } - - return null; - } - /// /// Attempts to cancel an in-progress operation on a connection. 
/// diff --git a/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj b/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj index 0680015a19..8800081f8c 100644 --- a/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj +++ b/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj @@ -6,7 +6,8 @@ readme.md - + + diff --git a/csharp/src/Apache.Arrow.Adbc/Extensions/IArrowArrayExtensions.cs b/csharp/src/Apache.Arrow.Adbc/Extensions/IArrowArrayExtensions.cs new file mode 100644 index 0000000000..bcde4045f7 --- /dev/null +++ b/csharp/src/Apache.Arrow.Adbc/Extensions/IArrowArrayExtensions.cs @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.IO; +using System.Text.Json; +using Apache.Arrow.Types; + +namespace Apache.Arrow.Adbc.Extensions +{ + public static class IArrowArrayExtensions + { + /// + /// Helper extension to get a value from the at the specified index. + /// + /// + /// The Arrow array. + /// + /// + /// The index in the array to get the value from. + /// + public static object? ValueAt(this IArrowArray arrowArray, int index) + { + if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray)); + if (index < 0) throw new ArgumentOutOfRangeException(nameof(index)); + + switch (arrowArray) + { + case BooleanArray booleanArray: + return booleanArray.GetValue(index); + case Date32Array date32Array: + return date32Array.GetDateTime(index); + case Date64Array date64Array: + return date64Array.GetDateTime(index); + case Decimal128Array decimal128Array: + return decimal128Array.GetSqlDecimal(index); + case Decimal256Array decimal256Array: + return decimal256Array.GetString(index); + case DoubleArray doubleArray: + return doubleArray.GetValue(index); + case FloatArray floatArray: + return floatArray.GetValue(index); +#if NET5_0_OR_GREATER + case PrimitiveArray halfFloatArray: + return halfFloatArray.GetValue(index); +#endif + case Int8Array int8Array: + return int8Array.GetValue(index); + case Int16Array int16Array: + return int16Array.GetValue(index); + case Int32Array int32Array: + return int32Array.GetValue(index); + case Int64Array int64Array: + return int64Array.GetValue(index); + case StringArray stringArray: + return stringArray.GetString(index); +#if NET6_0_OR_GREATER + case Time32Array time32Array: + return time32Array.GetTime(index); + case Time64Array time64Array: + return time64Array.GetTime(index); +#else + case Time32Array time32Array: + int? 
time32 = time32Array.GetValue(index); + if (time32 == null) { return null; } + return ((Time32Type)time32Array.Data.DataType).Unit switch + { + TimeUnit.Second => TimeSpan.FromSeconds(time32.Value), + TimeUnit.Millisecond => TimeSpan.FromMilliseconds(time32.Value), + _ => throw new InvalidDataException("Unsupported time unit for Time32Type") + }; + case Time64Array time64Array: + long? time64 = time64Array.GetValue(index); + if (time64 == null) { return null; } + return ((Time64Type)time64Array.Data.DataType).Unit switch + { + TimeUnit.Microsecond => TimeSpan.FromTicks(time64.Value * 10), + TimeUnit.Nanosecond => TimeSpan.FromTicks(time64.Value / 100), + _ => throw new InvalidDataException("Unsupported time unit for Time64Type") + }; +#endif + case TimestampArray timestampArray: + return timestampArray.GetTimestamp(index); + case UInt8Array uInt8Array: + return uInt8Array.GetValue(index); + case UInt16Array uInt16Array: + return uInt16Array.GetValue(index); + case UInt32Array uInt32Array: + return uInt32Array.GetValue(index); + case UInt64Array uInt64Array: + return uInt64Array.GetValue(index); + case DayTimeIntervalArray dayTimeIntervalArray: + return dayTimeIntervalArray.GetValue(index); + case MonthDayNanosecondIntervalArray monthDayNanosecondIntervalArray: + return monthDayNanosecondIntervalArray.GetValue(index); + case YearMonthIntervalArray yearMonthIntervalArray: + return yearMonthIntervalArray.GetValue(index); + case BinaryArray binaryArray: + if (!binaryArray.IsNull(index)) + { + return binaryArray.GetBytes(index).ToArray(); + } + else + { + return null; + } + case ListArray listArray: + return listArray.GetSlicedValues(index); + case StructArray structArray: + return SerializeToJson(structArray, index); + + // not covered: + // -- map array + // -- dictionary array + // -- fixed size binary + // -- union array + } + + return null; + } + + /// + /// Converts a StructArray to a JSON string. + /// + private static string SerializeToJson(StructArray structArray, int index) + { + Dictionary? jsonDictionary = ParseStructArray(structArray, index); + + return JsonSerializer.Serialize(jsonDictionary); + } + + /// + /// Converts a StructArray to a Dictionary. + /// + private static Dictionary? ParseStructArray(StructArray structArray, int index) + { + if (structArray.IsNull(index)) + return null; + + Dictionary jsonDictionary = new Dictionary(); + StructType structType = (StructType)structArray.Data.DataType; + for (int i = 0; i < structArray.Data.Children.Length; i++) + { + string name = structType.Fields[i].Name; + object? value = ValueAt(structArray.Fields[i], index); + + if (value is StructArray structArray1) + { + List?> children = new List?>(); + + for (int j = 0; j < structArray1.Length; j++) + { + children.Add(ParseStructArray(structArray1, j)); + } + + if (children.Count > 0) + { + jsonDictionary.Add(name, children); + } + else + { + jsonDictionary.Add(name, ParseStructArray(structArray1, index)); + } + } + else if (value is IArrowArray arrowArray) + { + IList? values = CreateList(arrowArray); + + if (values != null) + { + for (int j = 0; j < arrowArray.Length; j++) + { + values.Add(ValueAt(arrowArray, j)); + } + + jsonDictionary.Add(name, values); + } + else + { + jsonDictionary.Add(name, new List()); + } + } + else + { + jsonDictionary.Add(name, value); + } + } + + return jsonDictionary; + } + + /// + /// Creates a List based on the type of the Arrow array. + /// + /// + /// + /// + private static IList? 
CreateList(IArrowArray arrowArray) + { + if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray)); + + switch (arrowArray) + { + case BooleanArray booleanArray: + return new List(); + case Date32Array date32Array: + case Date64Array date64Array: + return new List(); + case Decimal128Array decimal128Array: + return new List(); + case Decimal256Array decimal256Array: + return new List(); + case DoubleArray doubleArray: + return new List(); + case FloatArray floatArray: + return new List(); +#if NET5_0_OR_GREATER + case PrimitiveArray halfFloatArray: + return new List(); +#endif + case Int8Array int8Array: + return new List(); + case Int16Array int16Array: + return new List(); + case Int32Array int32Array: + return new List(); + case Int64Array int64Array: + return new List(); + case StringArray stringArray: + return new List(); +#if NET6_0_OR_GREATER + case Time32Array time32Array: + case Time64Array time64Array: + return new List(); +#else + case Time32Array time32Array: + case Time64Array time64Array: + return new List(); +#endif + case TimestampArray timestampArray: + return new List(); + case UInt8Array uInt8Array: + return new List(); + case UInt16Array uInt16Array: + return new List(); + case UInt32Array uInt32Array: + return new List(); + case UInt64Array uInt64Array: + return new List(); + + case BinaryArray binaryArray: + return new List(); + + // not covered: + // -- struct array + // -- dictionary array + // -- fixed size binary + // -- list array + // -- union array + } + + return null; + } + } +} diff --git a/csharp/src/Apache.Arrow.Adbc/Extensions/StandardSchemaExtensions.cs b/csharp/src/Apache.Arrow.Adbc/Extensions/StandardSchemaExtensions.cs index 3e95aac483..6b1dc48f99 100644 --- a/csharp/src/Apache.Arrow.Adbc/Extensions/StandardSchemaExtensions.cs +++ b/csharp/src/Apache.Arrow.Adbc/Extensions/StandardSchemaExtensions.cs @@ -53,6 +53,7 @@ public static IReadOnlyList Validate(this IReadOnlyList sche { Field field = schemaFields[i]; ArrayData dataField = data[i].Data; + if (field.DataType.TypeId != dataField.DataType.TypeId) { throw new ArgumentException($"Expecting data type {field.DataType} but found {data[i].Data.DataType} on field with name {field.Name}.", nameof(data)); @@ -65,10 +66,25 @@ public static IReadOnlyList Validate(this IReadOnlyList sche else if (field.DataType.TypeId == ArrowTypeId.List) { ListType listType = (ListType)field.DataType; - if (listType.Fields.Count > 0) + int j = 0; + Field f = listType.Fields[j]; + + List fieldsToValidate = new List(); + List arrayDataToValidate = new List(); + + ArrayData? child = j < dataField.Children.Length ? 
dataField.Children[j] : null; + + if (child != null) { - Validate(listType.Fields, dataField.Children.Select(e => new ContainerArray(e)).ToList()); + fieldsToValidate.Add(f); + arrayDataToValidate.Add(new ContainerArray(child)); } + else if (!f.IsNullable) + { + throw new InvalidOperationException("Received a null value for a non-nullable field"); + } + + Validate(fieldsToValidate, arrayDataToValidate); } else if (field.DataType.TypeId == ArrowTypeId.Union) { diff --git a/csharp/src/Apache.Arrow.Adbc/Properties/AssemblyInfo.cs b/csharp/src/Apache.Arrow.Adbc/Properties/AssemblyInfo.cs index 302a85450e..f9fe442ad2 100644 --- a/csharp/src/Apache.Arrow.Adbc/Properties/AssemblyInfo.cs +++ b/csharp/src/Apache.Arrow.Adbc/Properties/AssemblyInfo.cs @@ -17,4 +17,5 @@ [assembly: InternalsVisibleTo("Apache.Arrow.Adbc.Drivers.Apache, PublicKey=0024000004800000940000000602000000240000525341310004000001000100e504183f6d470d6b67b6d19212be3e1f598f70c246a120194bc38130101d0c1853e4a0f2232cb12e37a7a90e707aabd38511dac4f25fcb0d691b2aa265900bf42de7f70468fc997551a40e1e0679b605aa2088a4a69e07c117e988f5b1738c570ee66997fba02485e7856a49eca5fd0706d09899b8312577cbb9034599fc92d4")] [assembly: InternalsVisibleTo("Apache.Arrow.Adbc.Drivers.BigQuery, PublicKey=0024000004800000940000000602000000240000525341310004000001000100e504183f6d470d6b67b6d19212be3e1f598f70c246a120194bc38130101d0c1853e4a0f2232cb12e37a7a90e707aabd38511dac4f25fcb0d691b2aa265900bf42de7f70468fc997551a40e1e0679b605aa2088a4a69e07c117e988f5b1738c570ee66997fba02485e7856a49eca5fd0706d09899b8312577cbb9034599fc92d4")] +[assembly: InternalsVisibleTo("Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql, PublicKey=0024000004800000940000000602000000240000525341310004000001000100e504183f6d470d6b67b6d19212be3e1f598f70c246a120194bc38130101d0c1853e4a0f2232cb12e37a7a90e707aabd38511dac4f25fcb0d691b2aa265900bf42de7f70468fc997551a40e1e0679b605aa2088a4a69e07c117e988f5b1738c570ee66997fba02485e7856a49eca5fd0706d09899b8312577cbb9034599fc92d4")] [assembly: InternalsVisibleTo("Apache.Arrow.Adbc.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100e504183f6d470d6b67b6d19212be3e1f598f70c246a120194bc38130101d0c1853e4a0f2232cb12e37a7a90e707aabd38511dac4f25fcb0d691b2aa265900bf42de7f70468fc997551a40e1e0679b605aa2088a4a69e07c117e988f5b1738c570ee66997fba02485e7856a49eca5fd0706d09899b8312577cbb9034599fc92d4")] diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs index 5b3ed7c24a..a317ca19cc 100644 --- a/csharp/src/Client/AdbcCommand.cs +++ b/csharp/src/Client/AdbcCommand.cs @@ -16,9 +16,15 @@ */ using System; +using System.Collections; +using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Data.SqlTypes; +using System.Globalization; +using System.Linq; using System.Threading.Tasks; +using Apache.Arrow.Types; namespace Apache.Arrow.Adbc.Client { @@ -27,9 +33,11 @@ namespace Apache.Arrow.Adbc.Client /// public sealed class AdbcCommand : DbCommand { - private AdbcStatement adbcStatement; + private readonly AdbcStatement _adbcStatement; + private AdbcParameterCollection? _dbParameterCollection; private int _timeout = 30; private bool _disposed; + private string? _commandTimeoutProperty; /// /// Overloaded. Initializes . 
@@ -45,7 +53,8 @@ public AdbcCommand(AdbcConnection adbcConnection) : base() this.DbConnection = adbcConnection; this.DecimalBehavior = adbcConnection.DecimalBehavior; - this.adbcStatement = adbcConnection.CreateStatement(); + this.StructBehavior = adbcConnection.StructBehavior; + this._adbcStatement = adbcConnection.CreateStatement(); } /// @@ -61,29 +70,39 @@ public AdbcCommand(string query, AdbcConnection adbcConnection) : base() if (adbcConnection == null) throw new ArgumentNullException(nameof(adbcConnection)); - this.adbcStatement = adbcConnection.CreateStatement(); + this._adbcStatement = adbcConnection.CreateStatement(); this.CommandText = query; this.DbConnection = adbcConnection; this.DecimalBehavior = adbcConnection.DecimalBehavior; + this.StructBehavior = adbcConnection.StructBehavior; } // For testing internal AdbcCommand(AdbcStatement adbcStatement, AdbcConnection adbcConnection) { - this.adbcStatement = adbcStatement; + this._adbcStatement = adbcStatement; this.DbConnection = adbcConnection; this.DecimalBehavior = adbcConnection.DecimalBehavior; + this.StructBehavior = adbcConnection.StructBehavior; + + if (adbcConnection.CommandTimeoutValue != null) + { + this.AdbcCommandTimeoutProperty = adbcConnection.CommandTimeoutValue.DriverPropertyName; + this.CommandTimeout = adbcConnection.CommandTimeoutValue.Value; + } } /// /// Gets the associated with /// this . /// - public AdbcStatement AdbcStatement => _disposed ? throw new ObjectDisposedException(nameof(AdbcCommand)) : this.adbcStatement; + public AdbcStatement AdbcStatement => _disposed ? throw new ObjectDisposedException(nameof(AdbcCommand)) : this._adbcStatement; public DecimalBehavior DecimalBehavior { get; set; } + public StructBehavior StructBehavior { get; set; } + public override string CommandText { get => AdbcStatement.SqlQuery ?? string.Empty; @@ -108,10 +127,43 @@ public override CommandType CommandType } } + /// + /// Gets or sets the name of the command timeout property for the underlying ADBC driver. + /// + public string AdbcCommandTimeoutProperty + { + get + { + if (string.IsNullOrEmpty(_commandTimeoutProperty)) + throw new InvalidOperationException("CommandTimeoutProperty is not set."); + + return _commandTimeoutProperty!; + } + set => _commandTimeoutProperty = value; + } + public override int CommandTimeout { get => _timeout; - set => _timeout = value; + set + { + // ensures the property exists before setting the CommandTimeout value + string property = AdbcCommandTimeoutProperty; + _adbcStatement.SetOption(property, value.ToString(CultureInfo.InvariantCulture)); + _timeout = value; + } + } + + protected override DbParameterCollection DbParameterCollection + { + get + { + if (_dbParameterCollection == null) + { + _dbParameterCollection = new AdbcParameterCollection(); + } + return _dbParameterCollection; + } } /// @@ -127,6 +179,7 @@ public byte[]? 
SubstraitPlan public override int ExecuteNonQuery() { + BindParameters(); return Convert.ToInt32(AdbcStatement.ExecuteUpdate().AffectedRows); } @@ -137,6 +190,7 @@ public override int ExecuteNonQuery() /// public long ExecuteUpdate() { + BindParameters(); return AdbcStatement.ExecuteUpdate().AffectedRows; } @@ -146,6 +200,7 @@ public long ExecuteUpdate() /// public QueryResult ExecuteQuery() { + BindParameters(); QueryResult executed = AdbcStatement.ExecuteQuery(); return executed; @@ -183,7 +238,7 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) case CommandBehavior.SchemaOnly: // The schema is not known until a read happens case CommandBehavior.Default: QueryResult result = this.ExecuteQuery(); - return new AdbcDataReader(this, result, this.DecimalBehavior, closeConnection); + return new AdbcDataReader(this, result, this.DecimalBehavior, this.StructBehavior, closeConnection); default: throw new InvalidOperationException($"{behavior} is not supported with this provider"); @@ -195,13 +250,239 @@ protected override void Dispose(bool disposing) if (disposing && !_disposed) { // TODO: ensure not in the middle of pulling - this.adbcStatement.Dispose(); + this._adbcStatement.Dispose(); _disposed = true; } base.Dispose(disposing); } + private void BindParameters() + { + if (_dbParameterCollection?.Count > 0) + { + Field[] fields = new Field[_dbParameterCollection.Count]; + IArrowArray[] parameters = new IArrowArray[_dbParameterCollection.Count]; + for (int i = 0; i < fields.Length; i++) + { + AdbcParameter param = (AdbcParameter)_dbParameterCollection[i]; + switch (param.DbType) + { + case DbType.Binary: + var binaryBuilder = new BinaryArray.Builder(); + switch (param.Value) + { + case null: binaryBuilder.AppendNull(); break; + case byte[] array: binaryBuilder.Append(array.AsSpan()); break; + default: throw new NotSupportedException($"Values of type {param.Value.GetType().Name} cannot be bound as binary"); + } + parameters[i] = binaryBuilder.Build(); + break; + case DbType.Boolean: + var boolBuilder = new BooleanArray.Builder(); + switch (param.Value) + { + case null: boolBuilder.AppendNull(); break; + case bool boolValue: boolBuilder.Append(boolValue); break; + default: boolBuilder.Append(ConvertValue(param.Value, Convert.ToBoolean, DbType.Boolean)); break; + } + parameters[i] = boolBuilder.Build(); + break; + case DbType.Byte: + var uint8Builder = new UInt8Array.Builder(); + switch (param.Value) + { + case null: uint8Builder.AppendNull(); break; + case byte byteValue: uint8Builder.Append(byteValue); break; + default: uint8Builder.Append(ConvertValue(param.Value, Convert.ToByte, DbType.Byte)); break; + } + parameters[i] = uint8Builder.Build(); + break; + case DbType.Date: + var dateBuilder = new Date32Array.Builder(); + switch (param.Value) + { + case null: dateBuilder.AppendNull(); break; + case DateTime datetime: dateBuilder.Append(datetime); break; +#if NET5_0_OR_GREATER + case DateOnly dateonly: dateBuilder.Append(dateonly); break; +#endif + default: dateBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.Date)); break; + } + parameters[i] = dateBuilder.Build(); + break; + case DbType.DateTime: + var timestampBuilder = new TimestampArray.Builder(); + switch (param.Value) + { + case null: timestampBuilder.AppendNull(); break; + case DateTime datetime: timestampBuilder.Append(datetime); break; + default: timestampBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.DateTime)); break; + } + parameters[i] = timestampBuilder.Build(); 
+ break; + case DbType.Decimal: + var value = param.Value switch + { + null => (SqlDecimal?)null, + SqlDecimal sqlDecimal => sqlDecimal, + decimal d => new SqlDecimal(d), + _ => new SqlDecimal(ConvertValue(param.Value, Convert.ToDecimal, DbType.Decimal)), + }; + var decimalBuilder = new Decimal128Array.Builder(new Decimal128Type(value?.Precision ?? 10, value?.Scale ?? 0)); + if (value is null) + { + decimalBuilder.AppendNull(); + } + else + { + decimalBuilder.Append(value.Value); + } + parameters[i] = decimalBuilder.Build(); + break; + case DbType.Double: + var doubleBuilder = new DoubleArray.Builder(); + switch (param.Value) + { + case null: doubleBuilder.AppendNull(); break; + case double dbl: doubleBuilder.Append(dbl); break; + default: doubleBuilder.Append(ConvertValue(param.Value, Convert.ToDouble, DbType.Double)); break; + } + parameters[i] = doubleBuilder.Build(); + break; + case DbType.Int16: + var int16Builder = new Int16Array.Builder(); + switch (param.Value) + { + case null: int16Builder.AppendNull(); break; + case short shortValue: int16Builder.Append(shortValue); break; + default: int16Builder.Append(ConvertValue(param.Value, Convert.ToInt16, DbType.Int16)); break; + } + parameters[i] = int16Builder.Build(); + break; + case DbType.Int32: + var int32Builder = new Int32Array.Builder(); + switch (param.Value) + { + case null: int32Builder.AppendNull(); break; + case int intValue: int32Builder.Append(intValue); break; + default: int32Builder.Append(ConvertValue(param.Value, Convert.ToInt32, DbType.Int32)); break; + } + parameters[i] = int32Builder.Build(); + break; + case DbType.Int64: + var int64Builder = new Int64Array.Builder(); + switch (param.Value) + { + case null: int64Builder.AppendNull(); break; + case long longValue: int64Builder.Append(longValue); break; + default: int64Builder.Append(ConvertValue(param.Value, Convert.ToInt64, DbType.Int64)); break; + } + parameters[i] = int64Builder.Build(); + break; + case DbType.SByte: + var int8Builder = new Int8Array.Builder(); + switch (param.Value) + { + case null: int8Builder.AppendNull(); break; + case sbyte sbyteValue: int8Builder.Append(sbyteValue); break; + default: int8Builder.Append(ConvertValue(param.Value, Convert.ToSByte, DbType.SByte)); break; + } + parameters[i] = int8Builder.Build(); + break; + case DbType.Single: + var floatBuilder = new FloatArray.Builder(); + switch (param.Value) + { + case null: floatBuilder.AppendNull(); break; + case float floatValue: floatBuilder.Append(floatValue); break; + default: floatBuilder.Append(ConvertValue(param.Value, Convert.ToSingle, DbType.Single)); break; + } + parameters[i] = floatBuilder.Build(); + break; + case DbType.String: + var stringBuilder = new StringArray.Builder(); + switch (param.Value) + { + case null: stringBuilder.AppendNull(); break; + case string stringValue: stringBuilder.Append(stringValue); break; + default: stringBuilder.Append(ConvertValue(param.Value, Convert.ToString, DbType.String)); break; + } + parameters[i] = stringBuilder.Build(); + break; + case DbType.Time: + var timeBuilder = new Time32Array.Builder(); + switch (param.Value) + { + case null: timeBuilder.AppendNull(); break; + case DateTime datetime: timeBuilder.Append((int)(datetime.TimeOfDay.Ticks / TimeSpan.TicksPerMillisecond)); break; +#if NET5_0_OR_GREATER + case TimeOnly timeonly: timeBuilder.Append(timeonly); break; +#endif + default: + DateTime convertedDateTime = ConvertValue(param.Value, Convert.ToDateTime, DbType.Time); + timeBuilder.Append((int)(convertedDateTime.TimeOfDay.Ticks / 
TimeSpan.TicksPerMillisecond)); + break; + } + parameters[i] = timeBuilder.Build(); + break; + case DbType.UInt16: + var uint16Builder = new UInt16Array.Builder(); + switch (param.Value) + { + case null: uint16Builder.AppendNull(); break; + case ushort ushortValue: uint16Builder.Append(ushortValue); break; + default: uint16Builder.Append(ConvertValue(param.Value, Convert.ToUInt16, DbType.UInt16)); break; + } + parameters[i] = uint16Builder.Build(); + break; + case DbType.UInt32: + var uint32Builder = new UInt32Array.Builder(); + switch (param.Value) + { + case null: uint32Builder.AppendNull(); break; + case uint uintValue: uint32Builder.Append(uintValue); break; + default: uint32Builder.Append(ConvertValue(param.Value, Convert.ToUInt32, DbType.UInt32)); break; + } + parameters[i] = uint32Builder.Build(); + break; + case DbType.UInt64: + var uint64Builder = new UInt64Array.Builder(); + switch (param.Value) + { + case null: uint64Builder.AppendNull(); break; + case ulong ulongValue: uint64Builder.Append(ulongValue); break; + default: uint64Builder.Append(ConvertValue(param.Value, Convert.ToUInt64, DbType.UInt64)); break; + } + parameters[i] = uint64Builder.Build(); + break; + default: + throw new NotSupportedException($"Parameters of type {param.DbType} are not supported"); + } + + fields[i] = new Field( + string.IsNullOrWhiteSpace(param.ParameterName) ? Guid.NewGuid().ToString() : param.ParameterName, + parameters[i].Data.DataType, + param.IsNullable || param.Value == null); + } + + Schema schema = new Schema(fields, null); + AdbcStatement.Bind(new RecordBatch(schema, parameters, 1), schema); + } + } + + private static T ConvertValue(object value, Func converter, DbType type) + { + try + { + return converter(value); + } + catch (Exception) + { + throw new NotSupportedException($"Values of type {value.GetType().Name} cannot be bound as {type}."); + } + } + #if NET5_0_OR_GREATER public override ValueTask DisposeAsync() { @@ -214,8 +495,6 @@ public override ValueTask DisposeAsync() public override UpdateRowSource UpdatedRowSource { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } - protected override DbParameterCollection DbParameterCollection => throw new NotImplementedException(); - protected override DbTransaction? 
DbTransaction { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } public override void Cancel() @@ -235,9 +514,60 @@ public override void Prepare() protected override DbParameter CreateDbParameter() { - throw new NotImplementedException(); + return new AdbcParameter(); } #endregion + + private class AdbcParameterCollection : DbParameterCollection + { + readonly List _parameters = new List(); + + public override int Count => _parameters.Count; + + public override object SyncRoot => throw new NotImplementedException(); + + public override int Add(object value) + { + int result = _parameters.Count; + _parameters.Add((AdbcParameter)value); + return result; + } + + public override void AddRange(System.Array values) => _parameters.AddRange(values.Cast()); + public override void Clear() => _parameters.Clear(); + public override bool Contains(object value) => _parameters.Contains((AdbcParameter)value); + public override bool Contains(string value) => IndexOf(value) >= 0; + public override void CopyTo(System.Array array, int index) => throw new NotImplementedException(); + public override IEnumerator GetEnumerator() => _parameters.GetEnumerator(); + public override int IndexOf(object value) => _parameters.IndexOf((AdbcParameter)value); + public override int IndexOf(string parameterName) => GetParameterIndex(parameterName, throwOnFailure: false); + public override void Insert(int index, object value) => _parameters.Insert(index, (AdbcParameter)value); + public override void Remove(object value) => _parameters.Remove((AdbcParameter)value); + public override void RemoveAt(int index) => _parameters.RemoveAt(index); + public override void RemoveAt(string parameterName) => _parameters.RemoveAt(GetParameterIndex(parameterName)); + protected override DbParameter GetParameter(int index) => _parameters[index]; + protected override DbParameter GetParameter(string parameterName) => _parameters[GetParameterIndex(parameterName)]; + protected override void SetParameter(int index, DbParameter value) => _parameters[index] = (AdbcParameter)value; + protected override void SetParameter(string parameterName, DbParameter value) => throw new NotImplementedException(); + + private int GetParameterIndex(string parameterName, bool throwOnFailure = true) + { + for (int i = 0; i < _parameters.Count; i++) + { + if (parameterName == _parameters[i].ParameterName) + { + return i; + } + } + + if (throwOnFailure) + { + throw new IndexOutOfRangeException("parameterName not found"); + } + + return -1; + } + } } } diff --git a/csharp/src/Client/AdbcConnection.cs b/csharp/src/Client/AdbcConnection.cs index 5347ba7798..a8d805f7e2 100644 --- a/csharp/src/Client/AdbcConnection.cs +++ b/csharp/src/Client/AdbcConnection.cs @@ -34,6 +34,7 @@ public sealed class AdbcConnection : DbConnection { private AdbcDatabase? adbcDatabase; private Adbc.AdbcConnection? adbcConnectionInternal; + private TimeoutValue? connectionTimeoutValue; private readonly Dictionary adbcConnectionParameters; private readonly Dictionary adbcConnectionOptions; @@ -122,6 +123,8 @@ internal AdbcStatement CreateStatement() return this.adbcConnectionInternal!.CreateStatement(); } + internal TimeoutValue? CommandTimeoutValue { get; private set; } + #if NET5_0_OR_GREATER [AllowNull] #endif @@ -132,11 +135,35 @@ internal AdbcStatement CreateStatement() /// public DecimalBehavior DecimalBehavior { get; set; } + /// + /// Indicates how structs should be treated. 
+ /// + public StructBehavior StructBehavior { get; set; } = StructBehavior.JsonString; + + public override int ConnectionTimeout + { + get + { + if (connectionTimeoutValue != null) + return connectionTimeoutValue.Value; + else + return base.ConnectionTimeout; + } + } + protected override DbCommand CreateDbCommand() { EnsureConnectionOpen(); - return new AdbcCommand(this); + AdbcCommand cmd = new AdbcCommand(this); + + if (CommandTimeoutValue != null) + { + cmd.AdbcCommandTimeoutProperty = CommandTimeoutValue.DriverPropertyName!; + cmd.CommandTimeout = CommandTimeoutValue.Value; + } + + return cmd; } /// @@ -232,7 +259,27 @@ private void SetConnectionProperties(string value) object? builderValue = builder[key]; if (builderValue != null) { - this.adbcConnectionParameters.Add(key, Convert.ToString(builderValue)!); + string paramValue = Convert.ToString(builderValue)!; + + switch (key) + { + case ConnectionStringKeywords.DecimalBehavior: + this.DecimalBehavior = (DecimalBehavior)Enum.Parse(typeof(DecimalBehavior), paramValue); + break; + case ConnectionStringKeywords.StructBehavior: + this.StructBehavior = (StructBehavior)Enum.Parse(typeof(StructBehavior), paramValue); + break; + case ConnectionStringKeywords.CommandTimeout: + CommandTimeoutValue = ConnectionStringParser.ParseTimeoutValue(paramValue); + break; + case ConnectionStringKeywords.ConnectionTimeout: + this.connectionTimeoutValue = ConnectionStringParser.ParseTimeoutValue(paramValue); + this.adbcConnectionParameters[connectionTimeoutValue.DriverPropertyName] = connectionTimeoutValue.Value.ToString(); + break; + default: + this.adbcConnectionParameters.Add(key, paramValue); + break; + } } } } diff --git a/csharp/src/Client/AdbcDataReader.cs b/csharp/src/Client/AdbcDataReader.cs index 124399865b..219e25e90a 100644 --- a/csharp/src/Client/AdbcDataReader.cs +++ b/csharp/src/Client/AdbcDataReader.cs @@ -26,10 +26,19 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Extensions; using Apache.Arrow.Types; namespace Apache.Arrow.Adbc.Client { + /// + /// Invoked when a value is read from an Arrow array. + /// + /// The Arrow array. + /// The item index. + /// The value at the index. + public delegate object GetValueEventHandler(IArrowArray arrowArray, int index); + /// /// Represents a DbDataReader over Arrow record batches /// @@ -44,7 +53,15 @@ public sealed class AdbcDataReader : DbDataReader, IDbColumnSchemaGenerator private bool isClosed; private int recordsAffected = -1; - internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult, DecimalBehavior decimalBehavior, bool closeConnection) + /// + /// An event that is raised when a value is read from an IArrowArray. + /// + /// + /// Callers may opt to provide overrides for parsing values. + /// + public event GetValueEventHandler? OnGetValue; + + internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult, DecimalBehavior decimalBehavior, StructBehavior structBehavior, bool closeConnection) { if (adbcCommand == null) throw new ArgumentNullException(nameof(adbcCommand)); @@ -65,11 +82,12 @@ internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult, De this.closeConnection = closeConnection; this.isClosed = false; this.DecimalBehavior = decimalBehavior; + this.StructBehavior = structBehavior; } public override object this[int ordinal] => GetValue(ordinal); - public override object this[string name] => GetValue(this.RecordBatch.Column(name), GetOrdinal(name)) ?? 
DBNull.Value; + public override object this[string name] => GetValue(this.RecordBatch.Column(name)) ?? DBNull.Value; public override int Depth => 0; @@ -86,6 +104,8 @@ internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult, De public DecimalBehavior DecimalBehavior { get; set; } + public StructBehavior StructBehavior { get; set; } + public override int RecordsAffected => this.recordsAffected; /// @@ -230,7 +250,7 @@ public override string GetString(int ordinal) public override object GetValue(int ordinal) { - object? value = GetValue(this.RecordBatch.Column(ordinal), ordinal); + object? value = GetValue(this.RecordBatch.Column(ordinal)); if (value == null) return DBNull.Value; @@ -301,7 +321,7 @@ public override bool Read() public override DataTable? GetSchemaTable() { - return SchemaConverter.ConvertArrowSchema(this.schema, this.adbcCommand.AdbcStatement, this.DecimalBehavior); + return SchemaConverter.ConvertArrowSchema(this.schema, this.adbcCommand.AdbcStatement, this.DecimalBehavior, this.StructBehavior); } #if NET5_0_OR_GREATER @@ -321,7 +341,7 @@ public ReadOnlyCollection GetAdbcColumnSchema() foreach (Field f in this.schema.FieldsList) { - Type t = SchemaConverter.ConvertArrowType(f, this.DecimalBehavior); + Type t = SchemaConverter.ConvertArrowType(f, this.DecimalBehavior, this.StructBehavior); if (f.HasMetadata && f.Metadata.ContainsKey("precision") && @@ -341,14 +361,19 @@ public ReadOnlyCollection GetAdbcColumnSchema() } /// - /// Gets the value from an IArrowArray at the current row index + /// Gets the value from an IArrowArray at the current row index. /// /// /// - public object? GetValue(IArrowArray arrowArray, int ordinal) + private object? GetValue(IArrowArray arrowArray) { - Field field = this.schema.GetFieldByIndex(ordinal); - return this.adbcCommand.AdbcStatement.GetValue(arrowArray, this.currentRowInRecordBatch); + // if the OnGetValue event is set, call it + object? result = OnGetValue?.Invoke(arrowArray, this.currentRowInRecordBatch); + + // if the value is null, try to get the value from the ArrowArray + result = result ?? arrowArray.ValueAt(this.currentRowInRecordBatch); + + return result; } /// diff --git a/csharp/src/Client/AdbcParameter.cs b/csharp/src/Client/AdbcParameter.cs new file mode 100644 index 0000000000..c816b1a0b8 --- /dev/null +++ b/csharp/src/Client/AdbcParameter.cs @@ -0,0 +1,48 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System.Data.Common; +using System.Data; +using System.Diagnostics.CodeAnalysis; +using System; + +namespace Apache.Arrow.Adbc.Client +{ + sealed public class AdbcParameter : DbParameter + { + public override DbType DbType { get; set; } + public override ParameterDirection Direction + { + get => ParameterDirection.Input; + set { if (value != ParameterDirection.Input) { throw new NotSupportedException(); } } + } + public override bool IsNullable { get; set; } = true; +#if NET5_0_OR_GREATER + [AllowNull] +#endif + public override string ParameterName { get; set; } = string.Empty; + public override int Size { get; set; } +#if NET5_0_OR_GREATER + [AllowNull] +#endif + public override string SourceColumn { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + public override bool SourceColumnNullMapping { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + public override object? Value { get; set; } + + public override void ResetDbType() => throw new NotImplementedException(); + } +} diff --git a/csharp/src/Client/ConnectionStringParser.cs b/csharp/src/Client/ConnectionStringParser.cs new file mode 100644 index 0000000000..3fb602870d --- /dev/null +++ b/csharp/src/Client/ConnectionStringParser.cs @@ -0,0 +1,81 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Text.RegularExpressions; + +namespace Apache.Arrow.Adbc.Client +{ + internal class ConnectionStringKeywords + { + public const string ConnectionTimeout = "adbcconnectiontimeout"; + public const string CommandTimeout = "adbccommandtimeout"; + public const string StructBehavior = "structbehavior"; + public const string DecimalBehavior = "decimalbehavior"; + } + + internal class ConnectionStringParser + { + public static TimeoutValue ParseTimeoutValue(string value) + { + string pattern = @"\(([^,]+),\s*([^,]+),\s*([^,]+)\)"; + + // Match the regex + Match match = Regex.Match(value, pattern); + + if (match.Success) + { + string driverPropertyName = match.Groups[1].Value.Trim(); + string timeoutAsString = match.Groups[2].Value.Trim(); + string units = match.Groups[3].Value.Trim(); + + if (units != "s" && units != "ms") + { + throw new InvalidOperationException("invalid units"); + } + + TimeoutValue timeoutValue = new TimeoutValue + { + DriverPropertyName = driverPropertyName, + Value = int.Parse(timeoutAsString), + Units = units + }; + + return timeoutValue; + } + else + { + throw new ArgumentOutOfRangeException(nameof(value)); + } + } + } + + internal class TimeoutValue + { + public string DriverPropertyName { get; set; } = string.Empty; + + public int Value { get; set; } + + // seconds=s + // milliseconds=ms + /// + /// While these can be helpful, the DbConnection and DbCommand + /// objects limit the use of these. + /// + public string Units { get; set; } = string.Empty; + } +} diff --git a/csharp/src/Client/SchemaConverter.cs b/csharp/src/Client/SchemaConverter.cs index 1d43207b2b..14e1058c0a 100644 --- a/csharp/src/Client/SchemaConverter.cs +++ b/csharp/src/Client/SchemaConverter.cs @@ -19,6 +19,7 @@ using System.Data; using System.Data.Common; using System.Data.SqlTypes; +using Apache.Arrow.Scalars; using Apache.Arrow.Types; namespace Apache.Arrow.Adbc.Client @@ -32,7 +33,7 @@ internal class SchemaConverter /// The Arrow schema /// The AdbcStatement to use /// - public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStatement, DecimalBehavior decimalBehavior) + public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStatement, DecimalBehavior decimalBehavior, StructBehavior structBehavior) { if (schema == null) throw new ArgumentNullException(nameof(schema)); @@ -60,7 +61,7 @@ public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStat row[SchemaTableColumn.ColumnOrdinal] = columnOrdinal; row[SchemaTableColumn.AllowDBNull] = f.IsNullable; row[SchemaTableColumn.ProviderType] = f.DataType; - Type t = ConvertArrowType(f, decimalBehavior); + Type t = ConvertArrowType(f, decimalBehavior, structBehavior); row[SchemaTableColumn.DataType] = t; @@ -110,7 +111,7 @@ public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStat /// /// /// - public static Type ConvertArrowType(Field f, DecimalBehavior decimalBehavior) + public static Type ConvertArrowType(Field f, DecimalBehavior decimalBehavior, StructBehavior structBehavior) { switch (f.DataType.TypeId) { @@ -119,11 +120,11 @@ public static Type ConvertArrowType(Field f, DecimalBehavior decimalBehavior) IArrowType valueType = list.ValueDataType; return GetArrowArrayType(valueType); default: - return GetArrowType(f, decimalBehavior); + return GetArrowType(f, decimalBehavior, structBehavior); } } - public static Type GetArrowType(Field f, DecimalBehavior decimalBehavior) + public static Type GetArrowType(Field f, DecimalBehavior 
decimalBehavior, StructBehavior structBehavior) { switch (f.DataType.TypeId) { @@ -182,7 +183,10 @@ public static Type GetArrowType(Field f, DecimalBehavior decimalBehavior) return typeof(string); case ArrowTypeId.Struct: - goto default; + if (structBehavior == StructBehavior.JsonString) + return typeof(string); + else + goto default; case ArrowTypeId.Timestamp: return typeof(DateTimeOffset); @@ -190,6 +194,17 @@ public static Type GetArrowType(Field f, DecimalBehavior decimalBehavior) case ArrowTypeId.Null: return typeof(DBNull); + case ArrowTypeId.Interval: + switch (((IntervalType)f.DataType).Unit) { + case IntervalUnit.MonthDayNanosecond: + return typeof(MonthDayNanosecondInterval); + case IntervalUnit.DayTime: + return typeof(DayTimeInterval); + case IntervalUnit.YearMonth: + return typeof(YearMonthInterval); + } + goto default; + default: return f.DataType.GetType(); } diff --git a/csharp/src/Client/StructBehavior.cs b/csharp/src/Client/StructBehavior.cs new file mode 100644 index 0000000000..9911a1aecb --- /dev/null +++ b/csharp/src/Client/StructBehavior.cs @@ -0,0 +1,32 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +namespace Apache.Arrow.Adbc.Client +{ + public enum StructBehavior + { + /// + /// Serialized as a JSON string + /// + JsonString, + + /// + /// Leave as native StructArray + /// + Strict + } +} diff --git a/csharp/src/Client/readme.md b/csharp/src/Client/readme.md index 2c4a51aca3..59f647a37e 100644 --- a/csharp/src/Client/readme.md +++ b/csharp/src/Client/readme.md @@ -73,3 +73,12 @@ if using the default user name and password authentication, but look like when using JWT authentication with an unencrypted key file. Other ADBC drivers will have different connection parameters, so be sure to check the documentation for each driver. + +### Connection Keywords +Because the ADO.NET client is designed to work with multiple drivers, callers need to indicate which driver-specific properties should be set for particular values. These can be provided either as properties on the objects directly, or parsed from the connection string. +These properties are: + +- __AdbcConnectionTimeout__ - This specifies the connection timeout value. The value must be in the form `(driver.property.name, integer, unit)`, where the unit is one of `s` or `ms`. For example, `AdbcConnectionTimeout=(adbc.snowflake.sql.client_option.client_timeout,30,s)` would set the connection timeout to 30 seconds. +- __AdbcCommandTimeout__ - This specifies the command timeout value. This follows the same pattern as `AdbcConnectionTimeout` and sets the `AdbcCommandTimeoutProperty` and `CommandTimeout` values on the `AdbcCommand` object. +- __StructBehavior__ - This specifies the StructBehavior when working with Arrow Struct arrays.
The valid values are `JsonString` (the default) or `Strict` (treat the struct as a native type). +- __DecimalBehavior__ - This specifies the DecimalBehavior when parsing decimal values from Arrow libraries. The valid values are `UseSqlDecimal` or `OverflowDecimalAsString` where values like Decimal256 are treated as strings. diff --git a/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj b/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj index 8d075bdb2e..13efc4e0ba 100644 --- a/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj +++ b/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj @@ -5,7 +5,9 @@ - + + + diff --git a/csharp/src/Drivers/Apache/ApacheParameters.cs b/csharp/src/Drivers/Apache/ApacheParameters.cs new file mode 100644 index 0000000000..17c94be32a --- /dev/null +++ b/csharp/src/Drivers/Apache/ApacheParameters.cs @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace Apache.Arrow.Adbc.Drivers.Apache +{ + /// + /// Options common to all Apache drivers. + /// + public class ApacheParameters + { + public const string PollTimeMilliseconds = "adbc.apache.statement.polltime_ms"; + public const string BatchSize = "adbc.apache.statement.batch_size"; + public const string QueryTimeoutSeconds = "adbc.apache.statement.query_timeout_s"; + } +} diff --git a/csharp/src/Drivers/Apache/ApacheUtility.cs b/csharp/src/Drivers/Apache/ApacheUtility.cs new file mode 100644 index 0000000000..f1cb07e07f --- /dev/null +++ b/csharp/src/Drivers/Apache/ApacheUtility.cs @@ -0,0 +1,141 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Threading; + +namespace Apache.Arrow.Adbc.Drivers.Apache +{ + internal class ApacheUtility + { + internal const int QueryTimeoutSecondsDefault = 60; + + public enum TimeUnit + { + Seconds, + Milliseconds + } + + public static CancellationToken GetCancellationToken(int timeout, TimeUnit timeUnit) + { + TimeSpan span; + + if (timeout == 0 || timeout == int.MaxValue) + { + // the max TimeSpan for CancellationTokenSource is int.MaxValue in milliseconds (not TimeSpan.MaxValue) + // no matter what the unit is + span = TimeSpan.FromMilliseconds(int.MaxValue); + } + else + { + if (timeUnit == TimeUnit.Seconds) + { + span = TimeSpan.FromSeconds(timeout); + } + else + { + span = TimeSpan.FromMilliseconds(timeout); + } + } + + return GetCancellationToken(span); + } + + private static CancellationToken GetCancellationToken(TimeSpan timeSpan) + { + var cts = new CancellationTokenSource(timeSpan); + return cts.Token; + } + + public static bool QueryTimeoutIsValid(string key, string value, out int queryTimeoutSeconds) + { + if (!string.IsNullOrEmpty(value) && int.TryParse(value, out int queryTimeout) && (queryTimeout >= 0)) + { + queryTimeoutSeconds = queryTimeout; + return true; + } + else + { + throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value of 0 (infinite) or greater."); + } + } + + public static bool ContainsException(Exception exception, out T? containedException) where T : Exception + { + if (exception is AggregateException aggregateException) + { + foreach (Exception? ex in aggregateException.InnerExceptions) + { + if (ex is T ce) + { + containedException = ce; + return true; + } + } + } + + Exception? e = exception; + while (e != null) + { + if (e is T ce) + { + containedException = ce; + return true; + } + e = e.InnerException; + } + + containedException = null; + return false; + } + + public static bool ContainsException(Exception exception, Type? exceptionType, out Exception? containedException) + { + if (exception == null || exceptionType == null) + { + containedException = null; + return false; + } + + if (exception is AggregateException aggregateException) + { + foreach (Exception? ex in aggregateException.InnerExceptions) + { + if (exceptionType.IsInstanceOfType(ex)) + { + containedException = ex; + return true; + } + } + } + + Exception? e = exception; + while (e != null) + { + if (exceptionType.IsInstanceOfType(e)) + { + containedException = e; + return true; + } + e = e.InnerException; + } + + containedException = null; + return false; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/DataTypeConversion.cs b/csharp/src/Drivers/Apache/Hive2/DataTypeConversion.cs new file mode 100644 index 0000000000..0d0865d7ab --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/DataTypeConversion.cs @@ -0,0 +1,66 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + [Flags] + internal enum DataTypeConversion + { + Empty = 0, + None = 1, + Scalar = 2, + } + + internal static class DataTypeConversionParser + { + internal const string SupportedList = DataTypeConversionOptions.None + ", " + DataTypeConversionOptions.Scalar; + + internal static DataTypeConversion Parse(string? dataTypeConversion) + { + DataTypeConversion result = DataTypeConversion.Empty; + + if (string.IsNullOrWhiteSpace(dataTypeConversion)) + { + // Default + return DataTypeConversion.Scalar; + } + + string[] conversions = dataTypeConversion!.Split(','); + foreach (string? conversion in conversions) + { + result |= (conversion?.Trim().ToLowerInvariant()) switch + { + null or "" => DataTypeConversion.Empty, + DataTypeConversionOptions.None => DataTypeConversion.None, + DataTypeConversionOptions.Scalar => DataTypeConversion.Scalar, + _ => throw new ArgumentOutOfRangeException(nameof(dataTypeConversion), conversion, "Invalid or unsupported data type conversion"), + }; + } + + if (result.HasFlag(DataTypeConversion.None) && result.HasFlag(DataTypeConversion.Scalar)) + { + throw new ArgumentOutOfRangeException(nameof(dataTypeConversion), dataTypeConversion, "Conflicting data type conversion options"); + } + // Default + if (result == DataTypeConversion.Empty) result = DataTypeConversion.Scalar; + + return result; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/DecimalUtility.cs b/csharp/src/Drivers/Apache/Hive2/DecimalUtility.cs new file mode 100644 index 0000000000..8c7d076bda --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/DecimalUtility.cs @@ -0,0 +1,470 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Buffers.Text; +using System.Numerics; +using System.Text; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + internal static class DecimalUtility + { + private const byte AsciiZero = (byte)'0'; + private const int AsciiDigitMaxIndex = '9' - AsciiZero; + private const byte AsciiMinus = (byte)'-'; + private const byte AsciiPlus = (byte)'+'; + private const byte AsciiUpperE = (byte)'E'; + private const byte AsciiLowerE = (byte)'e'; + private const byte AsciiPeriod = (byte)'.'; + private const byte AsciiSpace = (byte)' '; + + /// + /// Gets the BigInteger bytes for the given string value. 
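        // --- Illustrative usage sketch (not part of this patch) --------------------------
        // A minimal example of what GetBytes produces for a hypothetical DECIMAL(5,2) target
        // stored in a 16-byte (Decimal128) buffer. The method name ExampleGetBytes and the
        // literal values are illustrative only; the actual caller is ConvertToDecimal128 in
        // HiveServer2Reader, which hands the resulting bytes to a Decimal128Array builder.
        private static void ExampleGetBytes()
        {
            ReadOnlySpan<byte> utf8Value = Encoding.UTF8.GetBytes("123.45");
            Span<byte> buffer = stackalloc byte[16];

            // "123.45" at scale 2 becomes the unscaled integer 12345; the buffer receives its
            // two's-complement bytes (0x39, 0x30 on little-endian hardware) padded with the
            // sign byte (0x00 here) out to the full 16-byte width.
            GetBytes(utf8Value, precision: 5, scale: 2, byteWidth: 16, buffer);
        }
        // ----------------------------------------------------------------------------------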
+ /// + /// The numeric string value to get bytes for. + /// The decimal precision for the target Decimal[128|256] + /// The decimal scale for the target Decimal[128|256] + /// The width in bytes for the target buffer. Should match the length of the bytes parameter. + /// The buffer to place the BigInteger bytes into. + /// + internal static void GetBytes(ReadOnlySpan value, int precision, int scale, int byteWidth, Span bytes) + { + if (precision < 1) + { + throw new ArgumentOutOfRangeException(nameof(precision), precision, "precision value must be greater than zero."); + } + if (scale < 0 || scale > precision) + { + throw new ArgumentOutOfRangeException(nameof(scale), scale, "scale value must be in the range 0 .. precision."); + } + if (byteWidth > bytes.Length) + { + throw new ArgumentOutOfRangeException(nameof(byteWidth), byteWidth, $"value for byteWidth {byteWidth} exceeds the the size of bytes."); + } + + BigInteger integerValue = ToBigInteger(value, precision, scale); + + FillBytes(bytes, integerValue, byteWidth); + } + + private static void FillBytes(Span bytes, BigInteger integerValue, int byteWidth) + { + int bytesWritten = 0; +#if NETCOREAPP + if (!integerValue.TryWriteBytes(bytes, out bytesWritten, false, !BitConverter.IsLittleEndian)) + { + throw new OverflowException("Could not extract bytes from integer value " + integerValue); + } +#else + byte[] tempBytes = integerValue.ToByteArray(); + bytesWritten = tempBytes.Length; + if (bytesWritten > byteWidth) + { + throw new OverflowException($"Decimal size greater than {byteWidth} bytes: {bytesWritten}"); + } + tempBytes.CopyTo(bytes); +#endif + byte fillByte = (byte)(integerValue < 0 ? 255 : 0); + for (int i = bytesWritten; i < byteWidth; i++) + { + bytes[i] = fillByte; + } + } + + private static BigInteger ToBigInteger(ReadOnlySpan value, int precision, int scale) + { + ReadOnlySpan significantValue = GetSignificantValue(value, precision, scale); +#if NETCOREAPP + // We can rely on the fact that all the characters in the span have already been confirmed to be ASCII (i.e., < 128) + Span chars = stackalloc char[significantValue.Length]; + Encoding.UTF8.GetChars(significantValue, chars); + return BigInteger.Parse(chars); +#else + return BigInteger.Parse(Encoding.UTF8.GetString(significantValue)); +#endif + } + + private static ReadOnlySpan GetSignificantValue(ReadOnlySpan value, int precision, int scale) + { + ParseDecimal(value, out ParserState state); + + ProcessDecimal(value, + precision, + scale, + state, + out byte sign, + out ReadOnlySpan integerSpan, + out ReadOnlySpan fractionalSpan, + out int neededScale); + + Span significant = new byte[precision + 1]; + BuildSignificantValue( + sign, + scale, + integerSpan, + fractionalSpan, + neededScale, + significant); + + return significant; + } + + private static void ProcessDecimal(ReadOnlySpan value, int precision, int scale, ParserState state, out byte sign, out ReadOnlySpan integerSpan, out ReadOnlySpan fractionalSpan, out int neededScale) + { + int int_length = 0; + int frac_length = 0; + int exponent = 0; + + if (state.IntegerStart != -1 && state.IntegerEnd != -1) int_length = state.IntegerEnd - state.IntegerStart + 1; + if (state.FractionalStart != -1 && state.FractionalEnd != -1) frac_length = state.FractionalEnd - state.FractionalStart + 1; + if (state.ExponentIndex != -1 && state.ExponentStart != -1 && state.ExponentEnd != -1 && state.ExponentEnd >= state.ExponentStart) + { + int expStart = state.ExpSignIndex != -1 ? 
state.ExpSignIndex : state.ExponentStart; + int expLength = state.ExponentEnd - expStart + 1; + ReadOnlySpan exponentSpan = value.Slice(expStart, expLength); + if (!Utf8Parser.TryParse(exponentSpan, out exponent, out int _)) + { + throw new FormatException($"unable to parse exponent value '{Encoding.UTF8.GetString(exponentSpan)}'"); + } + } + integerSpan = int_length > 0 ? value.Slice(state.IntegerStart, state.IntegerEnd - state.IntegerStart + 1) : []; + fractionalSpan = frac_length > 0 ? value.Slice(state.FractionalStart, state.FractionalEnd - state.FractionalStart + 1) : []; + Span tempSignificant; + if (exponent != 0) + { + tempSignificant = new byte[int_length + frac_length]; + if (int_length > 0) value.Slice(state.IntegerStart, state.IntegerEnd - state.IntegerStart + 1).CopyTo(tempSignificant.Slice(0)); + if (frac_length > 0) value.Slice(state.FractionalStart, state.FractionalEnd - state.FractionalStart + 1).CopyTo(tempSignificant.Slice(int_length)); + // Trim trailing zeros from combined string + while (tempSignificant[tempSignificant.Length - 1] == AsciiZero) + { + tempSignificant = tempSignificant.Slice(0, tempSignificant.Length - 1); + } + // Recalculate integer and fractional length + if (exponent > 0) + { + int_length = Math.Min(int_length + exponent, tempSignificant.Length); + frac_length = Math.Max(Math.Min(frac_length - exponent, tempSignificant.Length - int_length), 0); + } + else + { + int_length = Math.Max(int_length + exponent, 0); + frac_length = Math.Max(Math.Min(frac_length - exponent, tempSignificant.Length - int_length), 0); + } + // Reset the integer and fractional span + fractionalSpan = tempSignificant.Slice(int_length, frac_length); + integerSpan = tempSignificant.Slice(0, int_length); + // Trim leading zeros fron new integer span + while (integerSpan.Length > 0 && integerSpan[0] == AsciiZero) + { + integerSpan = integerSpan.Slice(1); + int_length -= 1; + } + } + + int neededPrecision = int_length + frac_length; + neededScale = frac_length; + if (neededPrecision > precision) + { + throw new OverflowException($"Decimal precision cannot be greater than that in the Arrow vector: {Encoding.UTF8.GetString(value)} has precision > {precision}"); + } + if (neededScale > scale) + { + throw new OverflowException($"Decimal scale cannot be greater than that in the Arrow vector: {Encoding.UTF8.GetString(value)} has scale > {scale}"); + } + sign = state.SignIndex != -1 ? 
value[state.SignIndex] : AsciiPlus; + } + + private static void BuildSignificantValue( + byte sign, + int scale, + ReadOnlySpan integerSpan, + ReadOnlySpan fractionalSpan, + int neededScale, + Span significant) + { + significant[0] = sign; + int end = 0; + integerSpan.CopyTo(significant.Slice(end + 1)); + end += integerSpan.Length; + fractionalSpan.CopyTo(significant.Slice(end + 1)); + end += fractionalSpan.Length; + + // Add trailing zeros to adjust for scale + while (neededScale < scale) + { + neededScale++; + end++; + significant[end] = AsciiZero; + } + } + + private enum ParseState + { + StartWhiteSpace, + SignOrDigitOrDecimal, + DigitOrDecimalOrExponent, + FractionOrExponent, + ExpSignOrExpValue, + ExpValue, + EndWhiteSpace, + Invalid, + } + + private struct ParserState + { + public ParseState CurrentState = ParseState.StartWhiteSpace; + public int SignIndex = -1; + public int IntegerStart = -1; + public int IntegerEnd = -1; + public int DecimalIndex = -1; + public int FractionalStart = -1; + public int FractionalEnd = -1; + public int ExponentIndex = -1; + public int ExpSignIndex = -1; + public int ExponentStart = -1; + public int ExponentEnd = -1; + public bool HasZero = false; + + public ParserState() { } + } + + private static void ParseDecimal(ReadOnlySpan value, out ParserState parserState) + { + ParserState state = new(); + int index = 0; + int length = value.Length; + while (index < length) + { + byte c = value[index]; + switch (state.CurrentState) + { + case ParseState.StartWhiteSpace: + if (c != AsciiSpace) + { + state.CurrentState = ParseState.SignOrDigitOrDecimal; + } + else + { + index++; + } + break; + case ParseState.SignOrDigitOrDecimal: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (!state.HasZero && c == AsciiZero) state.HasZero |= true; + state.IntegerStart = index; + state.IntegerEnd = index; + index++; + state.CurrentState = ParseState.DigitOrDecimalOrExponent; + } + else if (c == AsciiMinus || c == AsciiPlus) + { + state.SignIndex = index; + index++; + state.CurrentState = ParseState.DigitOrDecimalOrExponent; + } + else if (c == AsciiPeriod) + { + state.DecimalIndex = index; + index++; + state.CurrentState = ParseState.FractionOrExponent; + } + else if (c == AsciiSpace) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.DigitOrDecimalOrExponent: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (state.IntegerStart == -1) state.IntegerStart = index; + if (!state.HasZero && c == AsciiZero) state.HasZero |= true; + state.IntegerEnd = index; + index++; + } + else if (c == AsciiPeriod) + { + state.DecimalIndex = index; + index++; + state.CurrentState = ParseState.FractionOrExponent; + } + else if (c == AsciiUpperE || c == AsciiLowerE) + { + state.ExponentIndex = index; + index++; + state.CurrentState = ParseState.ExpSignOrExpValue; + } + else if (c == AsciiSpace) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.FractionOrExponent: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (state.FractionalStart == -1) state.FractionalStart = index; + if (!state.HasZero && c == AsciiZero) state.HasZero |= true; + state.FractionalEnd = index; + index++; + } + else if (c == AsciiUpperE || c == AsciiLowerE) + { + state.ExponentIndex = index; + index++; + state.CurrentState = 
ParseState.ExpSignOrExpValue; + } + else if (c == AsciiSpace) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.ExpSignOrExpValue: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (state.ExponentStart == -1) state.ExponentStart = index; + state.ExponentEnd = index; + index++; + state.CurrentState = ParseState.ExpValue; + } + else if (c == AsciiMinus || c == AsciiPlus) + { + state.ExpSignIndex = index; + index++; + state.CurrentState = ParseState.ExpValue; + } + else if (c == AsciiSpace) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.ExpValue: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (state.ExponentStart == -1) state.ExponentStart = index; + state.ExponentEnd = index; + index++; + } + else if (c == AsciiSpace) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.EndWhiteSpace: + if (c == AsciiSpace) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.Invalid: + throw new ArgumentOutOfRangeException(nameof(value), Encoding.UTF8.GetString(value), $"Invalid numeric value at index {index}."); + } + } + // Trim leading zeros from integer portion + if (state.IntegerStart != -1 && state.IntegerEnd != -1) + { + for (int i = state.IntegerStart; i <= state.IntegerEnd; i++) + { + if (value[i] != AsciiZero) break; + + state.IntegerStart = i + 1; + if (state.IntegerStart > state.IntegerEnd) + { + state.IntegerStart = -1; + state.IntegerEnd = -1; + break; + } + } + } + // Trim trailing zeros from fractional portion + if (state.FractionalStart != -1 && state.FractionalEnd != -1) + { + for (int i = state.FractionalEnd; i >= state.FractionalStart; i--) + { + if (value[i] != AsciiZero) break; + + state.FractionalEnd = i - 1; + if (state.FractionalStart > state.FractionalEnd) + { + state.FractionalStart = -1; + state.FractionalEnd = -1; + break; + } + } + } + // Must have a integer or fractional part. 
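            // --- Illustrative walk-through (not part of this patch) --------------------------
            // For the UTF-8 input " -00.120 " the state machine above records the sign at '-',
            // the integer digits "00" and the fractional digits "120"; the trimming passes above
            // then drop the leading zeros (leaving no integer part) and the trailing zero
            // (leaving "12"). For a hypothetical DECIMAL(9,9) target, ProcessDecimal and
            // BuildSignificantValue later pad this to the significant value "-120000000",
            // i.e. -0.12 expressed as an unscaled integer at scale 9.
            // ----------------------------------------------------------------------------------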
+ if (state.IntegerStart == -1 && state.FractionalStart == -1) + { + if (!state.HasZero) + throw new ArgumentOutOfRangeException(nameof(value), Encoding.UTF8.GetString(value), "input does not contain a valid numeric value."); + else + { + state.IntegerStart = value.IndexOf(AsciiZero); + state.IntegerEnd = state.IntegerStart; + } + } + + parserState = state; + } + } + +#if !NETCOREAPP + internal static class EncodingExtensions + { + public static string GetString(this Encoding encoding, ReadOnlySpan source) + { + return encoding.GetString(source.ToArray()); + } + } +#endif +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index 57bca5d1d7..b603fdcb11 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -17,7 +17,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; @@ -27,52 +26,106 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { - public abstract class HiveServer2Connection : AdbcConnection + internal abstract class HiveServer2Connection : AdbcConnection { - const string userAgent = "AdbcExperimental/0.0"; - - protected TOperationHandle? operationHandle; - protected readonly IReadOnlyDictionary properties; - internal TTransport? transport; - internal TCLIService.Client? client; - internal TSessionHandle? sessionHandle; + internal const long BatchSizeDefault = 50000; + internal const int PollTimeMillisecondsDefault = 500; + private const int ConnectTimeoutMillisecondsDefault = 30000; + private TTransport? _transport; + private TCLIService.Client? _client; private readonly Lazy _vendorVersion; private readonly Lazy _vendorName; internal HiveServer2Connection(IReadOnlyDictionary properties) { - this.properties = properties; + Properties = properties; // Note: "LazyThreadSafetyMode.PublicationOnly" is thread-safe initialization where // the first successful thread sets the value. If an exception is thrown, initialization // will retry until it successfully returns a value without an exception. // https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects _vendorVersion = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), LazyThreadSafetyMode.PublicationOnly); _vendorName = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), LazyThreadSafetyMode.PublicationOnly); + + if (properties.TryGetValue(ApacheParameters.QueryTimeoutSeconds, out string? queryTimeoutSecondsSettingValue)) + { + if (ApacheUtility.QueryTimeoutIsValid(ApacheParameters.QueryTimeoutSeconds, queryTimeoutSecondsSettingValue, out int queryTimeoutSeconds)) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + } + } } internal TCLIService.Client Client { - get { return this.client ?? throw new InvalidOperationException("connection not open"); } + get { return _client ?? 
throw new InvalidOperationException("connection not open"); } } - protected string VendorVersion => _vendorVersion.Value; + internal string VendorVersion => _vendorVersion.Value; - protected string VendorName => _vendorName.Value; + internal string VendorName => _vendorName.Value; + + protected internal int QueryTimeoutSeconds { get; set; } = ApacheUtility.QueryTimeoutSecondsDefault; + + internal IReadOnlyDictionary Properties { get; } internal async Task OpenAsync() { - TProtocol protocol = await CreateProtocolAsync(); - this.transport = protocol.Transport; - this.client = new TCLIService.Client(protocol); - - var s0 = await this.client.OpenSession(CreateSessionRequest()); - this.sessionHandle = s0.SessionHandle; + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds, ApacheUtility.TimeUnit.Milliseconds); + try + { + TTransport transport = CreateTransport(); + TProtocol protocol = await CreateProtocolAsync(transport, cancellationToken); + _transport = protocol.Transport; + _client = new TCLIService.Client(protocol); + TOpenSessionReq request = CreateSessionRequest(); + + TOpenSessionResp? session = await Client.OpenSession(request, cancellationToken); + + // Explicitly check the session status + if (session == null) + { + throw new HiveServer2Exception("Unable to open session. Unknown error."); + } + else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS) + { + throw new HiveServer2Exception(session.Status.ErrorMessage) + .SetNativeError(session.Status.ErrorCode) + .SetSqlState(session.Status.SqlState); + } + + SessionHandle = session.SessionHandle; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The operation timed out while attempting to open a session. Please try increasing connect timeout.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + // Handle other exceptions if necessary + throw new HiveServer2Exception($"An unexpected error occurred while opening the session. '{ex.Message}'", ex); + } } - protected abstract ValueTask CreateProtocolAsync(); + internal TSessionHandle? SessionHandle { get; private set; } + + protected internal DataTypeConversion DataTypeConversion { get; set; } = DataTypeConversion.None; + + protected internal HiveServer2TlsOption TlsOptions { get; set; } = HiveServer2TlsOption.Empty; + + protected internal int ConnectTimeoutMilliseconds { get; set; } = ConnectTimeoutMillisecondsDefault; + + protected abstract TTransport CreateTransport(); + + protected abstract Task CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default); protected abstract TOpenSessionReq CreateSessionRequest(); + internal abstract SchemaParser SchemaParser { get; } + + internal abstract IArrowArrayStream NewReader(T statement, Schema schema) where T : HiveServer2Statement; + public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? 
columnNamePattern) { throw new NotImplementedException(); @@ -83,55 +136,77 @@ public override IArrowArrayStream GetTableTypes() throw new NotImplementedException(); } - protected void PollForResponse() + internal static async Task PollForResponseAsync(TOperationHandle operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds, CancellationToken cancellationToken = default) { TGetOperationStatusResp? statusResponse = null; do { - if (statusResponse != null) { Thread.Sleep(500); } - TGetOperationStatusReq request = new TGetOperationStatusReq(this.operationHandle); - statusResponse = this.Client.GetOperationStatus(request).Result; + if (statusResponse != null) { await Task.Delay(pollTimeMilliseconds, cancellationToken); } + TGetOperationStatusReq request = new(operationHandle); + statusResponse = await client.GetOperationStatus(request, cancellationToken); } while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE); + + // Must be in the finished state to be valid. If not, typically a server error or timeout has occurred. + if (statusResponse.OperationState != TOperationState.FINISHED_STATE) + { + throw new HiveServer2Exception(statusResponse.ErrorMessage, AdbcStatusCode.InvalidState) + .SetSqlState(statusResponse.SqlState) + .SetNativeError(statusResponse.ErrorCode); + } } private string GetInfoTypeStringValue(TGetInfoType infoType) { TGetInfoReq req = new() { - SessionHandle = this.sessionHandle ?? throw new InvalidOperationException("session not created"), + SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), InfoType = infoType, }; - TGetInfoResp getInfoResp = Client.GetInfo(req).Result; - if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage) - .SetNativeError(getInfoResp.Status.ErrorCode) - .SetSqlState(getInfoResp.Status.SqlState); + TGetInfoResp getInfoResp = Client.GetInfo(req, cancellationToken).Result; + if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage) + .SetNativeError(getInfoResp.Status.ErrorCode) + .SetSqlState(getInfoResp.Status.SqlState); + } + + return getInfoResp.InfoValue.StringValue; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. 
'{ex.Message}'", ex); } - - return getInfoResp.InfoValue.StringValue; } public override void Dispose() { - if (this.client != null) + if (_client != null) { - TCloseSessionReq r6 = new TCloseSessionReq(this.sessionHandle); - this.client.CloseSession(r6).Wait(); - - this.transport?.Close(); - this.client.Dispose(); - this.transport = null; - this.client = null; + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + TCloseSessionReq r6 = new(SessionHandle); + _client.CloseSession(r6, cancellationToken).Wait(); + _transport?.Close(); + _client.Dispose(); + _transport = null; + _client = null; } } - protected Schema GetSchema() + internal static async Task GetResultSetMetadataAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) { - TGetResultSetMetadataReq request = new TGetResultSetMetadataReq(this.operationHandle); - TGetResultSetMetadataResp response = this.Client.GetResultSetMetadata(request).Result; - return SchemaParser.GetArrowSchema(response.Schema); + TGetResultSetMetadataReq request = new(operationHandle); + TGetResultSetMetadataResp response = await client.GetResultSetMetadata(request, cancellationToken); + return response; } } } diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs new file mode 100644 index 0000000000..4f2bc62d21 --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + public static class DataTypeConversionOptions + { + public const string None = "none"; + public const string Scalar = "scalar"; + } + + public static class TlsOptions + { + public const string AllowSelfSigned = "allow_self_signed"; + public const string AllowHostnameMismatch = "allow_hostname_mismatch"; + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs new file mode 100644 index 0000000000..34dbf10f2c --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs @@ -0,0 +1,384 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Buffers.Text; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Globalization; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Ipc; +using Apache.Arrow.Types; +using Apache.Hive.Service.Rpc.Thrift; +using Thrift.Transport; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + internal class HiveServer2Reader : IArrowArrayStream + { + private const byte AsciiZero = (byte)'0'; + private const int AsciiDigitMaxIndex = '9' - AsciiZero; + private const byte AsciiDash = (byte)'-'; + private const byte AsciiSpace = (byte)' '; + private const byte AsciiColon = (byte)':'; + private const byte AsciiPeriod = (byte)'.'; + private const char StandardFormatRoundTrippable = 'O'; + private const char StandardFormatExponent = 'E'; + private const int YearMonthSepIndex = 4; + private const int MonthDaySepIndex = 7; + private const int KnownFormatDateLength = 10; + private const int KnownFormatDateTimeLength = 19; + private const int DayHourSepIndex = 10; + private const int HourMinuteSepIndex = 13; + private const int MinuteSecondSepIndex = 16; + private const int YearIndex = 0; + private const int MonthIndex = 5; + private const int DayIndex = 8; + private const int HourIndex = 11; + private const int MinuteIndex = 14; + private const int SecondIndex = 17; + private const int SecondSubsecondSepIndex = 19; + private const int SubsecondIndex = 20; + private const int MillisecondDecimalPlaces = 3; + private HiveServer2Statement? _statement; + private readonly DataTypeConversion _dataTypeConversion; + private static readonly IReadOnlyDictionary> s_arrowStringConverters = + new Dictionary>() + { + { ArrowTypeId.Date32, ConvertToDate32 }, + { ArrowTypeId.Decimal128, ConvertToDecimal128 }, + { ArrowTypeId.Timestamp, ConvertToTimestamp }, + }; + private static readonly IReadOnlyDictionary> s_arrowDoubleConverters = + new Dictionary>() + { + { ArrowTypeId.Float, ConvertToFloat }, + }; + + public HiveServer2Reader( + HiveServer2Statement statement, + Schema schema, + DataTypeConversion dataTypeConversion) + { + _statement = statement; + Schema = schema; + _dataTypeConversion = dataTypeConversion; + } + + public Schema Schema { get; } + + public async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) + { + // All records have been exhausted + if (_statement == null) + { + return null; + } + + try + { + // Await the fetch response + TFetchResultsResp response = await FetchNext(_statement, cancellationToken); + + int columnCount = GetColumnCount(response); + int rowCount = GetRowCount(response, columnCount); + if ((_statement.BatchSize > 0 && rowCount < _statement.BatchSize) || rowCount == 0) + { + // This is the last batch + _statement = null; + } + + // Build the current batch, if any data exists + return rowCount > 0 ? CreateBatch(response, columnCount, rowCount) : null; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? 
_) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + } + + private RecordBatch CreateBatch(TFetchResultsResp response, int columnCount, int rowCount) + { + IList columnData = []; + bool shouldConvertScalar = _dataTypeConversion.HasFlag(DataTypeConversion.Scalar); + for (int i = 0; i < columnCount; i++) + { + IArrowType? expectedType = shouldConvertScalar ? Schema.FieldsList[i].DataType : null; + IArrowArray columnArray = GetArray(response.Results.Columns[i], expectedType); + columnData.Add(columnArray); + } + + return new RecordBatch(Schema, columnData, rowCount); + } + + private static int GetColumnCount(TFetchResultsResp response) => + response.Results.Columns.Count; + + private static int GetRowCount(TFetchResultsResp response, int columnCount) => + columnCount > 0 ? GetArray(response.Results.Columns[0]).Length : 0; + + private static async Task FetchNext(HiveServer2Statement statement, CancellationToken cancellationToken = default) + { + var request = new TFetchResultsReq(statement.OperationHandle, TFetchOrientation.FETCH_NEXT, statement.BatchSize); + return await statement.Connection.Client.FetchResults(request, cancellationToken); + } + + public void Dispose() + { + } + + private static IArrowArray GetArray(TColumn column, IArrowType? expectedArrowType = default) + { + IArrowArray arrowArray = + (IArrowArray?)column.BoolVal?.Values ?? + (IArrowArray?)column.ByteVal?.Values ?? + (IArrowArray?)column.I16Val?.Values ?? + (IArrowArray?)column.I32Val?.Values ?? + (IArrowArray?)column.I64Val?.Values ?? + (IArrowArray?)column.DoubleVal?.Values ?? + (IArrowArray?)column.StringVal?.Values ?? + (IArrowArray?)column.BinaryVal?.Values ?? + throw new InvalidOperationException("unsupported data type"); + if (expectedArrowType != null && arrowArray is StringArray stringArray && s_arrowStringConverters.ContainsKey(expectedArrowType.TypeId)) + { + // Perform a conversion from string to native/scalar type. + Func converter = s_arrowStringConverters[expectedArrowType.TypeId]; + return converter(stringArray, expectedArrowType); + } + else if (expectedArrowType != null && arrowArray is DoubleArray doubleArray && s_arrowDoubleConverters.ContainsKey(expectedArrowType.TypeId)) + { + // Perform a conversion from double to another (float) type. + Func converter = s_arrowDoubleConverters[expectedArrowType.TypeId]; + return converter(doubleArray, expectedArrowType); + } + return arrowArray; + } + + internal static Date32Array ConvertToDate32(StringArray array, IArrowType _) + { + const DateTimeStyles DateTimeStyles = DateTimeStyles.AllowWhiteSpaces; + int length = array.Length; + var resultArray = new Date32Array + .Builder() + .Reserve(length); + for (int i = 0; i < length; i++) + { + // Work with UTF8 string. 
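                // --- Illustrative note (not part of this patch) ----------------------------------
                // The fast path below (TryParse) only accepts the exact "yyyy-MM-dd" shape,
                // e.g. "2024-03-15": length 10 with '-' at indices 4 and 7. Any other shape
                // (e.g. "2024-3-5" or an ISO round-trip string) falls through to Utf8Parser with
                // the 'O' standard format and finally to DateTime.TryParse on the UTF-16 string.
                // ----------------------------------------------------------------------------------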
+ ReadOnlySpan date = array.GetBytes(i, out bool isNull); + if (isNull) + { + resultArray.AppendNull(); + } + else if (TryParse(date, out DateTime dateTime) + || Utf8Parser.TryParse(date, out dateTime, out int _, standardFormat: StandardFormatRoundTrippable) + || DateTime.TryParse(array.GetString(i), CultureInfo.InvariantCulture, DateTimeStyles, out dateTime)) + { + resultArray.Append(dateTime); + } + else + { + throw new FormatException($"unable to convert value '{array.GetString(i)}' to DateTime"); + } + } + + return resultArray.Build(); + } + + internal static FloatArray ConvertToFloat(DoubleArray array, IArrowType _) + { + int length = array.Length; + var resultArray = new FloatArray + .Builder() + .Reserve(length); + for (int i = 0; i < length; i++) + { + resultArray.Append((float?)array.GetValue(i)); + } + + return resultArray.Build(); + } + + internal static bool TryParse(ReadOnlySpan date, out DateTime dateTime) + { + if (date.Length == KnownFormatDateLength + && date[YearMonthSepIndex] == AsciiDash && date[MonthDaySepIndex] == AsciiDash + && Utf8Parser.TryParse(date.Slice(YearIndex, 4), out int year, out int bytesConsumed) && bytesConsumed == 4 + && Utf8Parser.TryParse(date.Slice(MonthIndex, 2), out int month, out bytesConsumed) && bytesConsumed == 2 + && Utf8Parser.TryParse(date.Slice(DayIndex, 2), out int day, out bytesConsumed) && bytesConsumed == 2) + { + try + { + dateTime = new(year, month, day); + return true; + } + catch (ArgumentOutOfRangeException) + { + dateTime = default; + return false; + } + } + + dateTime = default; + return false; + } + + private static Decimal128Array ConvertToDecimal128(StringArray array, IArrowType schemaType) + { + int length = array.Length; + // Using the schema type to get the precision and scale. + Decimal128Type decimalType = (Decimal128Type)schemaType; + var resultArray = new Decimal128Array + .Builder(decimalType) + .Reserve(length); + Span buffer = stackalloc byte[decimalType.ByteWidth]; + + for (int i = 0; i < length; i++) + { + // Work with UTF8 string. + ReadOnlySpan item = array.GetBytes(i, out bool isNull); + if (isNull) + { + resultArray.AppendNull(); + } + // Try to parse the value into a decimal because it is the most performant and handles the exponent syntax. But this might overflow. + else if (Utf8Parser.TryParse(item, out decimal decimalValue, out int _, standardFormat: StandardFormatExponent)) + { + resultArray.Append(new SqlDecimal(decimalValue)); + } + else + { + DecimalUtility.GetBytes(item, decimalType.Precision, decimalType.Scale, decimalType.ByteWidth, buffer); + resultArray.Append(buffer); + } + } + return resultArray.Build(); + } + + internal static TimestampArray ConvertToTimestamp(StringArray array, IArrowType _) + { + const DateTimeStyles DateTimeStyles = DateTimeStyles.AssumeUniversal | DateTimeStyles.AllowWhiteSpaces; + int length = array.Length; + // Match the precision of the server + var resultArrayBuilder = new TimestampArray + .Builder(TimeUnit.Microsecond) + .Reserve(length); + for (int i = 0; i < length; i++) + { + // Work with UTF8 string. 
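                // --- Illustrative note (not part of this patch) ----------------------------------
                // The fast path below (TryParse) expects "yyyy-MM-dd HH:mm:ss[.fff...]" and assumes
                // UTC. For the example value "2024-03-15 10:20:30.123456" the subsecond digits
                // "123456" are converted to ticks as 123456 * (TicksPerMillisecond / 10^(6-3)) =
                // 1,234,560 ticks (123.456 ms); other shapes fall back to Utf8Parser ('O') and then
                // DateTimeOffset.TryParse.
                // ----------------------------------------------------------------------------------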
+ ReadOnlySpan date = array.GetBytes(i, out bool isNull); + if (isNull) + { + resultArrayBuilder.AppendNull(); + } + else if (TryParse(date, out DateTimeOffset dateValue) + || Utf8Parser.TryParse(date, out dateValue, out int _, standardFormat: StandardFormatRoundTrippable) + || DateTimeOffset.TryParse(array.GetString(i), CultureInfo.InvariantCulture, DateTimeStyles, out dateValue)) + { + resultArrayBuilder.Append(dateValue); + } + else + { + throw new FormatException($"unable to convert value '{array.GetString(i)}' to DateTimeOffset"); + } + } + + return resultArrayBuilder.Build(); + } + + internal static bool TryParse(ReadOnlySpan date, out DateTimeOffset dateValue) + { + bool isKnownFormat = date.Length >= KnownFormatDateTimeLength + && date[YearMonthSepIndex] == AsciiDash + && date[MonthDaySepIndex] == AsciiDash + && date[DayHourSepIndex] == AsciiSpace + && date[HourMinuteSepIndex] == AsciiColon + && date[MinuteSecondSepIndex] == AsciiColon; + + if (!isKnownFormat + || !Utf8Parser.TryParse(date.Slice(YearIndex, 4), out int year, out int bytesConsumed, standardFormat: 'D') || bytesConsumed != 4 + || !Utf8Parser.TryParse(date.Slice(MonthIndex, 2), out int month, out bytesConsumed, standardFormat: 'D') || bytesConsumed != 2 + || !Utf8Parser.TryParse(date.Slice(DayIndex, 2), out int day, out bytesConsumed, standardFormat: 'D') || bytesConsumed != 2 + || !Utf8Parser.TryParse(date.Slice(HourIndex, 2), out int hour, out bytesConsumed, standardFormat: 'D') || bytesConsumed != 2 + || !Utf8Parser.TryParse(date.Slice(MinuteIndex, 2), out int minute, out bytesConsumed, standardFormat: 'D') || bytesConsumed != 2 + || !Utf8Parser.TryParse(date.Slice(SecondIndex, 2), out int second, out bytesConsumed, standardFormat: 'D') || bytesConsumed != 2) + { + dateValue = default; + return false; + } + + try + { + dateValue = new(year, month, day, hour, minute, second, TimeSpan.Zero); + } + catch (ArgumentOutOfRangeException) + { + dateValue = default; + return false; + } + + // Retrieve subseconds, if available + int length = date.Length; + if (length > SecondSubsecondSepIndex) + { + if (date[SecondSubsecondSepIndex] == AsciiPeriod) + { + int start = -1; + int end = SubsecondIndex; + while (end < length && (uint)(date[end] - AsciiZero) <= AsciiDigitMaxIndex) + { + if (start == -1) start = end; + end++; + } + if (end < length) + { + // Indicates unrecognized trailing character(s) + dateValue = default; + return false; + } + + int subSecondsLength = start != -1 ? end - start : 0; + if (subSecondsLength > 0) + { + if (!Utf8Parser.TryParse(date.Slice(start, subSecondsLength), out int subSeconds, out _)) + { + dateValue = default; + return false; + } + + double factorOfMilliseconds = Math.Pow(10, subSecondsLength - MillisecondDecimalPlaces); + long ticks = (long)(subSeconds * (TimeSpan.TicksPerMillisecond / factorOfMilliseconds)); + dateValue = dateValue.AddTicks(ticks); + } + } + else + { + dateValue = default; + return false; + } + } + + return true; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs new file mode 100644 index 0000000000..5ef4c01cbe --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using Apache.Arrow.Types; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + internal class HiveServer2SchemaParser : SchemaParser + { + public override IArrowType GetArrowType(TPrimitiveTypeEntry thriftType, DataTypeConversion dataTypeConversion) + { + bool convertScalar = dataTypeConversion.HasFlag(DataTypeConversion.Scalar); + return thriftType.Type switch + { + TTypeId.BIGINT_TYPE => Int64Type.Default, + TTypeId.BINARY_TYPE => BinaryType.Default, + TTypeId.BOOLEAN_TYPE => BooleanType.Default, + TTypeId.DOUBLE_TYPE => DoubleType.Default, + TTypeId.FLOAT_TYPE => convertScalar ? FloatType.Default : DoubleType.Default, + TTypeId.INT_TYPE => Int32Type.Default, + TTypeId.SMALLINT_TYPE => Int16Type.Default, + TTypeId.TINYINT_TYPE => Int8Type.Default, + TTypeId.DATE_TYPE => convertScalar ? Date32Type.Default : StringType.Default, + TTypeId.DECIMAL_TYPE => convertScalar ? NewDecima128Type(thriftType) : StringType.Default, + TTypeId.TIMESTAMP_TYPE => convertScalar ? TimestampType.Default : StringType.Default, + TTypeId.CHAR_TYPE + or TTypeId.NULL_TYPE + or TTypeId.STRING_TYPE + or TTypeId.VARCHAR_TYPE + or TTypeId.INTERVAL_DAY_TIME_TYPE + or TTypeId.INTERVAL_YEAR_MONTH_TYPE + or TTypeId.ARRAY_TYPE + or TTypeId.MAP_TYPE + or TTypeId.STRUCT_TYPE + or TTypeId.UNION_TYPE + or TTypeId.USER_DEFINED_TYPE => StringType.Default, + TTypeId.TIMESTAMPLOCALTZ_TYPE => throw new NotImplementedException(), + _ => throw new NotImplementedException(), + }; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 4879bdecb3..06723e324d 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -16,49 +16,106 @@ */ using System; +using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; +using Thrift.Transport; namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { - public abstract class HiveServer2Statement : AdbcStatement + internal abstract class HiveServer2Statement : AdbcStatement { - private const int PollTimeMillisecondsDefault = 500; - private const int BatchSizeDefault = 50000; - protected internal HiveServer2Connection connection; - protected internal TOperationHandle? 
operationHandle; - protected HiveServer2Statement(HiveServer2Connection connection) { - this.connection = connection; + Connection = connection; } protected virtual void SetStatementProperties(TExecuteStatementReq statement) { + statement.QueryTimeout = QueryTimeoutSeconds; } - protected abstract IArrowArrayStream NewReader(T statement, Schema schema) where T : HiveServer2Statement; + public override QueryResult ExecuteQuery() + { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + return ExecuteQueryAsyncInternal(cancellationToken).Result; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + } - public override QueryResult ExecuteQuery() => ExecuteQueryAsync().AsTask().Result; + public override UpdateResult ExecuteUpdate() + { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + return ExecuteUpdateAsyncInternal(cancellationToken).Result; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + } - public override UpdateResult ExecuteUpdate() => ExecuteUpdateAsync().Result; + private async Task ExecuteQueryAsyncInternal(CancellationToken cancellationToken = default) + { + // this could either: + // take QueryTimeoutSeconds * 3 + // OR + // take QueryTimeoutSeconds (but this could be restricting) + await ExecuteStatementAsync(cancellationToken); // --> get QueryTimeout + + await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout + Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); // + get the result, up to QueryTimeout + + return new QueryResult(-1, Connection.NewReader(this, schema)); + } public override async ValueTask ExecuteQueryAsync() { - await ExecuteStatementAsync(); - await PollForResponseAsync(); - Schema schema = await GetSchemaAsync(); + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + return await ExecuteQueryAsyncInternal(cancellationToken); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + } - // TODO: Ensure this is set dynamically based on server capabilities - return new QueryResult(-1, NewReader(this, schema)); + private async Task GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) + { + TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, cancellationToken); + return Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion); } - public override async Task ExecuteUpdateAsync() + public async Task ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default) { const string NumberOfAffectedRowsColumnName = "num_affected_rows"; - - QueryResult queryResult = await ExecuteQueryAsync(); + QueryResult queryResult = await ExecuteQueryAsyncInternal(cancellationToken); if (queryResult.Stream == null) { throw new AdbcException("no data found"); @@ -73,11 +130,13 @@ public override async Task ExecuteUpdateAsync() throw new AdbcException($"Unexpected data type for column: '{NumberOfAffectedRowsColumnName}'", new ArgumentException(NumberOfAffectedRowsColumnName)); } - // If no altered rows, i.e. DDC statements, then -1 is the default. + // The default is -1. + if (affectedRowsField == null) return new UpdateResult(-1); + long? affectedRows = null; while (true) { - using RecordBatch nextBatch = await stream.ReadNextRecordBatchAsync(); + using RecordBatch nextBatch = await stream.ReadNextRecordBatchAsync(cancellationToken); if (nextBatch == null) { break; } Int64Array numOfModifiedArray = (Int64Array)nextBatch.Column(NumberOfAffectedRowsColumnName); // Note: should only have one item, but iterate for completeness @@ -88,85 +147,95 @@ public override async Task ExecuteUpdateAsync() } } + // If no altered rows, i.e. DDC statements, then -1 is the default. return new UpdateResult(affectedRows ?? -1); } + public override async Task ExecuteUpdateAsync() + { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + return await ExecuteUpdateAsyncInternal(cancellationToken); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. 
'{ex.Message}'", ex); + } + } + public override void SetOption(string key, string value) { switch (key) { - case Options.PollTimeMilliseconds: + case ApacheParameters.PollTimeMilliseconds: UpdatePollTimeIfValid(key, value); break; - case Options.BatchSize: + case ApacheParameters.BatchSize: UpdateBatchSizeIfValid(key, value); break; + case ApacheParameters.QueryTimeoutSeconds: + if (ApacheUtility.QueryTimeoutIsValid(key, value, out int queryTimeoutSeconds)) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + } + break; default: throw AdbcException.NotImplemented($"Option '{key}' is not implemented."); } } - protected async Task ExecuteStatementAsync() + protected async Task ExecuteStatementAsync(CancellationToken cancellationToken = default) { - TExecuteStatementReq executeRequest = new TExecuteStatementReq(this.connection.sessionHandle, this.SqlQuery); + TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle, SqlQuery); SetStatementProperties(executeRequest); - TExecuteStatementResp executeResponse = await this.connection.Client.ExecuteStatement(executeRequest); + TExecuteStatementResp executeResponse = await Connection.Client.ExecuteStatement(executeRequest, cancellationToken); if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) { throw new HiveServer2Exception(executeResponse.Status.ErrorMessage) .SetSqlState(executeResponse.Status.SqlState) .SetNativeError(executeResponse.Status.ErrorCode); } - this.operationHandle = executeResponse.OperationHandle; + OperationHandle = executeResponse.OperationHandle; } - protected async Task PollForResponseAsync() - { - TGetOperationStatusResp? statusResponse = null; - do - { - if (statusResponse != null) { await Task.Delay(PollTimeMilliseconds); } - TGetOperationStatusReq request = new TGetOperationStatusReq(this.operationHandle); - statusResponse = await this.connection.Client.GetOperationStatus(request); - } while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE); - } + protected internal int PollTimeMilliseconds { get; private set; } = HiveServer2Connection.PollTimeMillisecondsDefault; + + protected internal long BatchSize { get; private set; } = HiveServer2Connection.BatchSizeDefault; - protected async ValueTask GetSchemaAsync() + protected internal int QueryTimeoutSeconds { - TGetResultSetMetadataReq request = new TGetResultSetMetadataReq(this.operationHandle); - TGetResultSetMetadataResp response = await this.connection.Client.GetResultSetMetadata(request); - return SchemaParser.GetArrowSchema(response.Schema); + // Coordinate updates with the connection + get => Connection.QueryTimeoutSeconds; + set => Connection.QueryTimeoutSeconds = value; } - protected internal int PollTimeMilliseconds { get; private set; } = PollTimeMillisecondsDefault; + public HiveServer2Connection Connection { get; private set; } - protected internal int BatchSize { get; private set; } = BatchSizeDefault; - - /// - /// Provides the constant string key values to the method. - /// - public class Options - { - // Options common to all HiveServer2Statement-derived drivers go here - public const string PollTimeMilliseconds = "adbc.statement.polltime_milliseconds"; - public const string BatchSize = "adbc.statement.batch_size"; - } + public TOperationHandle? 
OperationHandle { get; private set; } private void UpdatePollTimeIfValid(string key, string value) => PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value, result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0 ? pollTimeMilliseconds - : throw new ArgumentException($"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to zero.", nameof(value)); + : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to 0."); - private void UpdateBatchSizeIfValid(string key, string value) => BatchSize = !string.IsNullOrEmpty(value) && int.TryParse(value, out int batchSize) && batchSize > 0 + private void UpdateBatchSizeIfValid(string key, string value) => BatchSize = !string.IsNullOrEmpty(value) && long.TryParse(value, out long batchSize) && batchSize > 0 ? batchSize - : throw new ArgumentException($"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than zero.", nameof(value)); + : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than zero."); public override void Dispose() { - if (this.operationHandle != null) + if (OperationHandle != null) { - TCloseOperationReq request = new TCloseOperationReq(this.operationHandle); - this.connection.Client.CloseOperation(request).Wait(); - this.operationHandle = null; + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + TCloseOperationReq request = new TCloseOperationReq(OperationHandle); + Connection.Client.CloseOperation(request, cancellationToken).Wait(); + OperationHandle = null; } base.Dispose(); diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2TlsOption.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2TlsOption.cs new file mode 100644 index 0000000000..84f56a485d --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2TlsOption.cs @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + [Flags] + internal enum HiveServer2TlsOption + { + Empty = 0, + AllowSelfSigned = 1, + AllowHostnameMismatch = 2, + } + + internal static class TlsOptionsParser + { + internal const string SupportedList = TlsOptions.AllowSelfSigned + "," + TlsOptions.AllowHostnameMismatch; + + internal static HiveServer2TlsOption Parse(string? 
tlsOptions) + { + HiveServer2TlsOption options = HiveServer2TlsOption.Empty; + if (tlsOptions == null) return options; + + string[] valueList = tlsOptions.Split(','); + foreach (string tlsOption in valueList) + { + options |= (tlsOption?.Trim().ToLowerInvariant()) switch + { + null or "" => HiveServer2TlsOption.Empty, + TlsOptions.AllowSelfSigned => HiveServer2TlsOption.AllowSelfSigned, + TlsOptions.AllowHostnameMismatch => HiveServer2TlsOption.AllowHostnameMismatch, + _ => throw new ArgumentOutOfRangeException(nameof(tlsOptions), tlsOption, "Invalid or unsupported TLS option"), + }; + } + return options; + } + } +} diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaAuthType.cs b/csharp/src/Drivers/Apache/Impala/ImpalaAuthType.cs new file mode 100644 index 0000000000..bbbea6f5a0 --- /dev/null +++ b/csharp/src/Drivers/Apache/Impala/ImpalaAuthType.cs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Impala +{ + internal enum ImpalaAuthType + { + Invalid = 0, + None, + UsernameOnly, + Basic, + Empty = int.MaxValue, + } + + internal static class AuthTypeOptionsParser + { + internal static bool TryParse(string? authType, out ImpalaAuthType authTypeValue) + { + switch (authType?.Trim().ToLowerInvariant()) + { + case null: + case "": + authTypeValue = ImpalaAuthType.Empty; + return true; + case AuthTypeOptions.None: + authTypeValue = ImpalaAuthType.None; + return true; + case AuthTypeOptions.UsernameOnly: + authTypeValue = ImpalaAuthType.UsernameOnly; + return true; + case AuthTypeOptions.Basic: + authTypeValue = ImpalaAuthType.Basic; + return true; + default: + authTypeValue = ImpalaAuthType.Invalid; + return false; + } + } + } +} diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs index e9e0018d15..0e673c7c4a 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs @@ -15,7 +15,9 @@ * limitations under the License. 
*/ +using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; using Apache.Arrow.Ipc; @@ -26,26 +28,36 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { - public class ImpalaConnection : HiveServer2Connection + internal class ImpalaConnection : HiveServer2Connection { + // https://impala.apache.org/docs/build/html/topics/impala_ports.html + // https://impala.apache.org/docs/build/html/topics/impala_client.html + private const int DefaultSocketTransportPort = 21050; + private const int DefaultHttpTransportPort = 28000; + internal ImpalaConnection(IReadOnlyDictionary properties) : base(properties) { } - protected override ValueTask CreateProtocolAsync() + protected override TTransport CreateTransport() { - string hostName = properties["HostName"]; + string hostName = Properties["HostName"]; string? tmp; - int port = 21050; // default? - if (properties.TryGetValue("Port", out tmp)) + int port = DefaultSocketTransportPort; // default? + if (Properties.TryGetValue("Port", out tmp)) { port = int.Parse(tmp); } TConfiguration config = new TConfiguration(); TTransport transport = new ThriftSocketTransport(hostName, port, config); - return new ValueTask(new TBinaryProtocol(transport)); + return transport; + } + + protected override Task CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default) + { + return Task.FromResult(new TBinaryProtocol(transport)); } protected override TOpenSessionReq CreateSessionRequest() @@ -72,5 +84,9 @@ public override IArrowArrayStream GetTableTypes() } public override Schema GetTableSchema(string? catalog, string? dbSchema, string tableName) => throw new System.NotImplementedException(); + + internal override SchemaParser SchemaParser { get; } = new HiveServer2SchemaParser(); + + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: DataTypeConversion); } } diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaParameters.cs b/csharp/src/Drivers/Apache/Impala/ImpalaParameters.cs new file mode 100644 index 0000000000..443e4180b1 --- /dev/null +++ b/csharp/src/Drivers/Apache/Impala/ImpalaParameters.cs @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Impala +{ + /// + /// Parameters used for connecting to Impala data sources. 
+ /// + public static class ImpalaParameters + { + public const string HostName = "adbc.impala.host"; + public const string Port = "adbc.impala.port"; + public const string Path = "adbc.impala.path"; + public const string AuthType = "adbc.impala.auth_type"; + public const string DataTypeConv = "adbc.impala.data_type_conv"; + } + + public static class AuthTypeOptions + { + public const string None = "none"; + public const string UsernameOnly = "username_only"; + public const string Basic = "basic"; + } +} diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs index 17b15b9522..f94ac3970e 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs @@ -16,94 +16,23 @@ */ using System; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; -using Apache.Arrow.Ipc; -using Apache.Hive.Service.Rpc.Thrift; -using Thrift; -using Thrift.Protocol; -using Thrift.Transport.Client; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { - public class ImpalaStatement : HiveServer2Statement + internal class ImpalaStatement : HiveServer2Statement { internal ImpalaStatement(ImpalaConnection connection) : base(connection) { } - public override object GetValue(IArrowArray arrowArray, int index) - { - throw new NotSupportedException(); - } - - protected override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema); - /// /// Provides the constant string key values to the method. /// - public new sealed class Options : HiveServer2Statement.Options + public sealed class Options : ApacheParameters { // options specific to Impala go here } - - class HiveServer2Reader : IArrowArrayStream - { - HiveServer2Statement? statement; - int counter; - - public HiveServer2Reader(HiveServer2Statement statement, Schema schema) - { - this.statement = statement; - this.Schema = schema; - } - - public Schema Schema { get; } - - public async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) - { - if (this.statement == null) - { - return null; - } - - TFetchResultsReq request = new TFetchResultsReq(this.statement.operationHandle, TFetchOrientation.FETCH_NEXT, 50000); - TFetchResultsResp response = await this.statement.connection.Client.FetchResults(request, cancellationToken); - - var buffer = new System.IO.MemoryStream(); - await response.WriteAsync(new TBinaryProtocol(new TStreamTransport(null, buffer, new TConfiguration())), cancellationToken); - System.IO.File.WriteAllBytes(string.Format("d:/src/buffer{0}.bin", this.counter++), buffer.ToArray()); - - RecordBatch result = new RecordBatch(this.Schema, response.Results.Columns.Select(GetArray), GetArray(response.Results.Columns[0]).Length); - - if (!response.HasMoreRows) - { - this.statement = null; - } - - return result; - } - - public void Dispose() - { - } - - static IArrowArray GetArray(TColumn column) - { - return - (IArrowArray?)column.BoolVal?.Values ?? - (IArrowArray?)column.ByteVal?.Values ?? - (IArrowArray?)column.I16Val?.Values ?? - (IArrowArray?)column.I32Val?.Values ?? - (IArrowArray?)column.I64Val?.Values ?? - (IArrowArray?)column.DoubleVal?.Values ?? - (IArrowArray?)column.StringVal?.Values ?? - (IArrowArray?)column.BinaryVal?.Values ?? 
- throw new InvalidOperationException("unsupported data type"); - } - } } } diff --git a/csharp/src/Drivers/Apache/Spark/README.md b/csharp/src/Drivers/Apache/Spark/README.md new file mode 100644 index 0000000000..3b5a0e79ed --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/README.md @@ -0,0 +1,140 @@ + + +# Spark Driver + +## Database and Connection Properties + +Properties should be passed in the call to `SparkDriver.Open`, +but can also be passed in the call to `AdbcDatabase.Connect`. + +| Property | Description | Default | +| :--- | :--- | :--- | +| `adbc.spark.type` | (Required) Indicates the Spark server type. One of `databricks`, `http` (future: `standard`) | | +| `adbc.spark.auth_type` | An indicator of the intended type of authentication. Allowed values: `none`, `username_only`, `basic`, and `token`. This property is optional. The authentication type can be inferred from `token`, `username`, and `password`. If a `token` value is provided, token authentication is used. Otherwise, if both `username` and `password` values are provided, basic authentication is used. | | +| `adbc.spark.host` | Host name for the data source. Do not include scheme or port number. Example: `sparkserver.region.cloudapp.azure.com` | | +| `adbc.spark.port` | The port number the data source listens on for a new connections. | `443` | +| `adbc.spark.path` | The URI path on the data source server. Example: `sql/protocolv1/o/0123456789123456/01234-0123456-source` | | +| `adbc.spark.token` | For token-based authentication, the token to be authenticated on the data source. Example: `abcdef0123456789` | | +| `uri` | The full URI that includes scheme, host, port and path. If set, this property takes precedence over `adbc.spark.host`, `adbc.spark.port` and `adbc.spark.path`. | | +| `username` | The user name used for basic authentication | | +| `password` | The password for the user name used for basic authentication. | | +| `adbc.spark.data_type_conv` | Comma-separated list of data conversion options. Each option indicates the type of conversion to perform on data returned from the Spark server.

Allowed values: `none`, `scalar`.

Option `none` indicates there is no conversion from the Spark type to a native type (i.e., no conversion from String to Timestamp for Apache Spark over HTTP). Example: `adbc.spark.data_type_conv=none`.

Option `scalar` will perform conversion (if necessary) from the Spark data type to the corresponding Arrow data type for the types `DATE/Date32/DateTime`, `DECIMAL/Decimal128/SqlDecimal`, and `TIMESTAMP/Timestamp/DateTimeOffset`. Example: `adbc.spark.data_type_conv=scalar` | `scalar` | +| `adbc.spark.tls_options` | Comma-separated list of TLS/SSL options. Each option indicates a TLS/SSL behavior to apply when connecting to a Spark server.

Allowed values: `allow_self_signed`, `allow_hostname_mismatch`.

Option `allow_self_signed` allows certificate errors due to an unknown certificate authority, typically when using a self-signed certificate. Option `allow_hostname_mismatch` allows certificate errors due to a mismatch of the hostname (e.g., when connecting through an SSH tunnel). Example: `adbc.spark.tls_options=allow_self_signed` | | +| `adbc.spark.connect_timeout_ms` | Sets the timeout (in milliseconds) to open a new session. Values can be 0 (infinite) or greater than zero. | `30000` | +| `adbc.apache.statement.batch_size` | Sets the maximum number of rows to retrieve in a single batch request. | `50000` | +| `adbc.apache.statement.polltime_ms` | If polling is necessary to get a result, this option sets the length of time (in milliseconds) to wait between polls. | `500` | +| `adbc.apache.statement.query_timeout_s` | Sets the maximum time (in seconds) for a query to complete. Values can be 0 (infinite) or greater than zero. | `60` | + +## Timeout Configuration + +The timeout settings form a hierarchy. As specified above, `adbc.spark.connect_timeout_ms` is analogous to a ConnectTimeout and is used to initially establish a new session with the server. + +The `adbc.apache.statement.query_timeout_s` option is analogous to a CommandTimeout for any subsequent requests to the server, including metadata calls and query execution. + +The `adbc.apache.statement.polltime_ms` option specifies the time between polls to the service, up to the limit specified by `adbc.apache.statement.query_timeout_s`. + +## Spark Types + +The following tables depict how the Spark ADBC driver converts a Spark type to an Arrow type and a .NET type: + +### Spark on Databricks + +| Spark Type | Arrow Type | C# Type | +| :--- | :---: | :---: | +| ARRAY* | String | string | +| BIGINT | Int64 | long | +| BINARY | Binary | byte[] | +| BOOLEAN | Boolean | bool | +| CHAR | String | string | +| DATE | Date32 | DateTime | +| DECIMAL | Decimal128 | SqlDecimal | +| DOUBLE | Double | double | +| FLOAT | Float | float | +| INT | Int32 | int | +| INTERVAL_DAY_TIME+ | String | string | +| INTERVAL_YEAR_MONTH+ | String | string | +| MAP* | String | string | +| NULL | Null | null | +| SMALLINT | Int16 | short | +| STRING | String | string | +| STRUCT* | String | string | +| TIMESTAMP | Timestamp | DateTimeOffset | +| TINYINT | Int8 | sbyte | +| UNION | String | string | +| USER_DEFINED | String | string | +| VARCHAR | String | string | + +### Apache Spark over HTTP (adbc.spark.data_type_conv = ?)
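The Arrow and C# types produced over HTTP depend on the `adbc.spark.data_type_conv` setting, as the table below shows. A minimal, hypothetical sketch of the difference (only the property keys documented above are taken from this README; the remaining values are placeholders):

```csharp
using System.Collections.Generic;

// Sketch only: choose how the driver materializes DATE/DECIMAL/TIMESTAMP columns.
var properties = new Dictionary<string, string>
{
    ["adbc.spark.type"] = "http",
    ["adbc.spark.data_type_conv"] = "scalar", // default; use "none" to keep the server's strings
    // ... host, port, path and authentication settings as described above ...
};

// With "scalar": a DATE column arrives as a Date32Array and reads back as DateTime.
// With "none":   the same column arrives as a StringArray, e.g. "2024-03-15".
```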
+ +| Spark Type | Arrow Type (`none`) | C# Type (`none`) | Arrow Type (`scalar`) | C# Type (`scalar`) | +| :--- | :---: | :---: | :---: | :---: | +| ARRAY* | String | string | | | +| BIGINT | Int64 | long | | | +| BINARY | Binary | byte[] | | | +| BOOLEAN | Boolean | bool | | | +| CHAR | String | string | | | +| DATE* | *String* | *string* | Date32 | DateTime | +| DECIMAL* | *String* | *string* | Decimal128 | SqlDecimal | +| DOUBLE | Double | double | | | +| FLOAT | *Double* | *double* | Float | float | +| INT | Int32 | int | | | +| INTERVAL_DAY_TIME+ | String | string | | | +| INTERVAL_YEAR_MONTH+ | String | string | | | +| MAP* | String | string | | | +| NULL | String | string | | | +| SMALLINT | Int16 | short | | | +| STRING | String | string | | | +| STRUCT* | String | string | | | +| TIMESTAMP* | *String* | *string* | Timestamp | DateTimeOffset | +| TINYINT | Int8 | sbyte | | | +| UNION | String | string | | | +| USER_DEFINED | String | string | | | +| VARCHAR | String | string | | | + +\* Types are returned as strings instead of "native" types
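Putting the connection properties and statement options documented above together, here is a minimal end-to-end sketch. It assumes the standard ADBC C# entry points referenced earlier in this README (`SparkDriver.Open`, `AdbcDatabase.Connect`); the host, path, token, and query below are placeholders, not working values:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Apache.Arrow;
using Apache.Arrow.Adbc;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
using Apache.Arrow.Ipc;

class SparkDriverExample
{
    static async Task Main()
    {
        var properties = new Dictionary<string, string>
        {
            ["adbc.spark.type"] = "databricks",
            ["adbc.spark.host"] = "example.cloud.databricks.com",        // placeholder
            ["adbc.spark.path"] = "sql/protocolv1/o/0/0000-000000-xxxx", // placeholder
            ["adbc.spark.token"] = "dapi-example-token",                 // placeholder
        };

        AdbcDatabase database = new SparkDriver().Open(properties);
        AdbcConnection connection = database.Connect(new Dictionary<string, string>());
        AdbcStatement statement = connection.CreateStatement();

        // Statement options from the table above (values are illustrative).
        statement.SetOption("adbc.apache.statement.query_timeout_s", "120");
        statement.SetOption("adbc.apache.statement.batch_size", "10000");
        statement.SetOption("adbc.apache.statement.polltime_ms", "250");

        statement.SqlQuery = "SELECT 1 AS one";
        QueryResult result = statement.ExecuteQuery();

        using IArrowArrayStream? stream = result.Stream;
        if (stream == null) return;

        while (true)
        {
            using RecordBatch? batch = await stream.ReadNextRecordBatchAsync();
            if (batch == null) break;
            Console.WriteLine($"Read a batch with {batch.Length} row(s)");
        }

        // Dispose the statement, connection, and database when finished in real code.
        statement.Dispose();
    }
}
```

For the Azure HDInsight variant described below, the same pattern applies, with the `http` type, port `443`, and the cluster's user name and password in place of a token.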
+\+ Interval types are returned as strings + +## Supported Variants + +### Spark on Databricks + +Support for Spark on Databricks is the most mature. + +The Spark ADBC driver supports token-based authentication using the +[Databricks personal access token](https://docs.databricks.com/en/dev-tools/auth/pat.html). +Basic (username and password) authentication is not supported at this time. + +### Apache Spark over HTTP + +Support for Apache Spark over HTTP is at an early stage. + +### Apache Spark Standard + +This is currently unsupported. + +### Azure Spark HDInsight + +To read data from an Azure HDInsight Spark cluster, use the following parameters: +adbc.spark.type = "http" +adbc.spark.port = "443" +adbc.spark.path = "/sparkhive2" +adbc.spark.host = $"{clusterHostName}" +username = $"{clusterUserName}" +password = $"{clusterPassword}" diff --git a/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs new file mode 100644 index 0000000000..8afb81c1e0 --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal enum SparkAuthType + { + Invalid = 0, + None, + UsernameOnly, + Basic, + Token, + Empty = int.MaxValue, + } + + internal static class AuthTypeParser + { + internal static bool TryParse(string?
authType, out SparkAuthType authTypeValue) + { + switch (authType?.Trim().ToLowerInvariant()) + { + case null: + case "": + authTypeValue = SparkAuthType.Empty; + return true; + case SparkAuthTypeConstants.None: + authTypeValue = SparkAuthType.None; + return true; + case SparkAuthTypeConstants.UsernameOnly: + authTypeValue = SparkAuthType.UsernameOnly; + return true; + case SparkAuthTypeConstants.Basic: + authTypeValue = SparkAuthType.Basic; + return true; + case SparkAuthTypeConstants.Token: + authTypeValue = SparkAuthType.Token; + return true; + default: + authTypeValue = SparkAuthType.Invalid; + return false; + } + } + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index 798fdeec29..b3c0c56ba1 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -19,8 +19,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using System.Net.Http; -using System.Net.Http.Headers; using System.Reflection; using System.Text; using System.Text.RegularExpressions; @@ -32,14 +30,13 @@ using Apache.Arrow.Ipc; using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; -using Thrift; -using Thrift.Protocol; +using Thrift.Transport; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { - public class SparkConnection : HiveServer2Connection + internal abstract class SparkConnection : HiveServer2Connection { - private readonly string UserAgent = $"{InfoDriverName.Replace(" ", "")}/{ProductVersionDefault}"; + internal static readonly string s_userAgent = $"{InfoDriverName.Replace(" ", "")}/{ProductVersionDefault}"; readonly AdbcInfoCode[] infoSupportedCodes = new[] { AdbcInfoCode.DriverName, @@ -255,54 +252,18 @@ internal enum ColumnTypeId internal SparkConnection(IReadOnlyDictionary properties) : base(properties) { + ValidateProperties(); _productVersion = new Lazy(() => GetProductVersion(), LazyThreadSafetyMode.PublicationOnly); } - protected string ProductVersion => _productVersion.Value; - - protected override async ValueTask CreateProtocolAsync() + private void ValidateProperties() { - Trace.TraceError($"create protocol with {properties.Count} properties."); - - foreach (var property in properties.Keys) - { - Trace.TraceError($"key = {property} value = {properties[property]}"); - } - - string hostName = properties[SparkParameters.HostName]; - string path = properties[SparkParameters.Path]; - string token; - - if (properties.ContainsKey(SparkParameters.Token)) - token = properties[SparkParameters.Token]; - else - token = properties[SparkParameters.Password]; - - HttpClient httpClient = new HttpClient(); - httpClient.BaseAddress = new UriBuilder(Uri.UriSchemeHttps, hostName, -1, path).Uri; - httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token); - httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent); - httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); - httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); - httpClient.DefaultRequestHeaders.ExpectContinue = false; - - TConfiguration config = new TConfiguration(); - - ThriftHttpTransport transport = new ThriftHttpTransport(httpClient, config); - // can switch to the one below if want to use the experimental one with IPeekableTransport - // ThriftHttpTransport transport = new ThriftHttpTransport(httpClient, config); - await transport.OpenAsync(CancellationToken.None); - return new 
TBinaryProtocol(transport); + ValidateAuthentication(); + ValidateConnection(); + ValidateOptions(); } - protected override TOpenSessionReq CreateSessionRequest() - { - return new TOpenSessionReq(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7) - { - CanUseMultipleCatalogs = true, - Configuration = timestampConfig, - }; - } + protected string ProductVersion => _productVersion.Value; public override AdbcStatement CreateStatement() { @@ -455,244 +416,297 @@ public override IArrowArrayStream GetTableTypes() { TGetTableTypesReq req = new() { - SessionHandle = this.sessionHandle ?? throw new InvalidOperationException("session not created"), + SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), GetDirectResults = sparkGetDirectResults }; - TGetTableTypesResp resp = this.Client.GetTableTypes(req).Result; - if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) + + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - throw new HiveServer2Exception(resp.Status.ErrorMessage) - .SetNativeError(resp.Status.ErrorCode) - .SetSqlState(resp.Status.SqlState); - } + TGetTableTypesResp resp = Client.GetTableTypes(req, cancellationToken).Result; + + if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(resp.Status.ErrorMessage) + .SetNativeError(resp.Status.ErrorCode) + .SetSqlState(resp.Status.SqlState); + } - List columns = resp.DirectResults.ResultSet.Results.Columns; - StringArray tableTypes = columns[0].StringVal.Values; + TRowSet rowSet = GetRowSetAsync(resp, cancellationToken).Result; + StringArray tableTypes = rowSet.Columns[0].StringVal.Values; - StringArray.Builder tableTypesBuilder = new StringArray.Builder(); - tableTypesBuilder.AppendRange(tableTypes); + StringArray.Builder tableTypesBuilder = new StringArray.Builder(); + tableTypesBuilder.AppendRange(tableTypes); - IArrowArray[] dataArrays = new IArrowArray[] - { + IArrowArray[] dataArrays = new IArrowArray[] + { tableTypesBuilder.Build() - }; + }; - return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); + return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex); + } } public override Schema GetTableSchema(string? catalog, string? dbSchema, string? 
tableName) { - TGetColumnsReq getColumnsReq = new TGetColumnsReq(this.sessionHandle); + TGetColumnsReq getColumnsReq = new TGetColumnsReq(SessionHandle); getColumnsReq.CatalogName = catalog; getColumnsReq.SchemaName = dbSchema; getColumnsReq.TableName = tableName; getColumnsReq.GetDirectResults = sparkGetDirectResults; - var columnsResponse = this.Client.GetColumns(getColumnsReq).Result; - if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - throw new Exception(columnsResponse.Status.ErrorMessage); - } + var columnsResponse = Client.GetColumns(getColumnsReq, cancellationToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } - var result = columnsResponse.DirectResults; - var resultSchema = result.ResultSetMetadata.ArrowSchema; - var columns = result.ResultSet.Results.Columns; - var rowCount = columns[3].StringVal.Values.Length; + TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; + List columns = rowSet.Columns; + int rowCount = rowSet.Columns[3].StringVal.Values.Length; - Field[] fields = new Field[rowCount]; - for (int i = 0; i < rowCount; i++) + Field[] fields = new Field[rowCount]; + for (int i = 0; i < rowCount; i++) + { + string columnName = columns[3].StringVal.Values.GetString(i); + int? columnType = columns[4].I32Val.Values.GetValue(i); + string typeName = columns[5].StringVal.Values.GetString(i); + // Note: the following two columns do not seem to be set correctly for DECIMAL types. + //int? columnSize = columns[6].I32Val.Values.GetValue(i); + //int? decimalDigits = columns[8].I32Val.Values.GetValue(i); + bool nullable = columns[10].I32Val.Values.GetValue(i) == 1; + IArrowType dataType = SparkConnection.GetArrowType(columnType!.Value, typeName); + fields[i] = new Field(columnName, dataType, nullable); + } + return new Schema(fields, null); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) { - string columnName = columns[3].StringVal.Values.GetString(i); - int? columnType = columns[4].I32Val.Values.GetValue(i); - string typeName = columns[5].StringVal.Values.GetString(i); - // Note: the following two columns do not seem to be set correctly for DECIMAL types. - //int? columnSize = columns[6].I32Val.Values.GetValue(i); - //int? decimalDigits = columns[8].I32Val.Values.GetValue(i); - bool nullable = columns[10].I32Val.Values.GetValue(i) == 1; - IArrowType dataType = SparkConnection.GetArrowType(columnType!.Value, typeName); - fields[i] = new Field(columnName, dataType, nullable); + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex); } - return new Schema(fields, null); } public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? 
columnNamePattern) { - Trace.TraceError($"getting objects with depth={depth.ToString()}, catalog = {catalogPattern}, dbschema = {dbSchemaPattern}, tablename = {tableNamePattern}"); - Dictionary>> catalogMap = new Dictionary>>(); - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(this.sessionHandle); - getCatalogsReq.GetDirectResults = sparkGetDirectResults; - - TGetCatalogsResp getCatalogsResp = this.Client.GetCatalogs(getCatalogsReq).Result; - if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) { - throw new Exception(getCatalogsResp.Status.ErrorMessage); - } - IReadOnlyDictionary columnMap = GetColumnIndexMap(getCatalogsResp.DirectResults.ResultSetMetadata.Schema.Columns); + TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); + getCatalogsReq.GetDirectResults = sparkGetDirectResults; - string catalogRegexp = PatternToRegEx(catalogPattern); - TRowSet resp = getCatalogsResp.DirectResults.ResultSet.Results; - IReadOnlyList list = resp.Columns[columnMap[TableCat]].StringVal.Values; - for (int i = 0; i < list.Count; i++) - { - string col = list[i]; - string catalog = col; + TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, cancellationToken).Result; - if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) { - catalogMap.Add(catalog, new Dictionary>()); + throw new Exception(getCatalogsResp.Status.ErrorMessage); } - } - } + var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) - { - TGetSchemasReq getSchemasReq = new TGetSchemasReq(this.sessionHandle); - getSchemasReq.CatalogName = catalogPattern; - getSchemasReq.SchemaName = dbSchemaPattern; - getSchemasReq.GetDirectResults = sparkGetDirectResults; + string catalogRegexp = PatternToRegEx(catalogPattern); + TRowSet rowSet = GetRowSetAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyList list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + for (int i = 0; i < list.Count; i++) + { + string col = list[i]; + string catalog = col; - TGetSchemasResp getSchemasResp = this.Client.GetSchemas(getSchemasReq).Result; - if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getSchemasResp.Status.ErrorMessage); + if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + { + catalogMap.Add(catalog, new Dictionary>()); + } + } + // Handle the case where server does not support 'catalog' in the namespace. 
+ if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) + { + catalogMap.Add(string.Empty, []); + } } - IReadOnlyDictionary columnMap = GetColumnIndexMap(getSchemasResp.DirectResults.ResultSetMetadata.Schema.Columns); - TRowSet resp = getSchemasResp.DirectResults.ResultSet.Results; - IReadOnlyList catalogList = resp.Columns[columnMap[TableCatalog]].StringVal.Values; - IReadOnlyList schemaList = resp.Columns[columnMap[TableSchem]].StringVal.Values; - - for (int i = 0; i < catalogList.Count; i++) + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). - catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary()); - } - } + TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); + getSchemasReq.CatalogName = catalogPattern; + getSchemasReq.SchemaName = dbSchemaPattern; + getSchemasReq.GetDirectResults = sparkGetDirectResults; - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) - { - TGetTablesReq getTablesReq = new TGetTablesReq(this.sessionHandle); - getTablesReq.CatalogName = catalogPattern; - getTablesReq.SchemaName = dbSchemaPattern; - getTablesReq.TableName = tableNamePattern; - getTablesReq.GetDirectResults = sparkGetDirectResults; - - TGetTablesResp getTablesResp = this.Client.GetTables(getTablesReq).Result; - if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getTablesResp.Status.ErrorMessage); - } + TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, cancellationToken).Result; + if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getSchemasResp.Status.ErrorMessage); + } - IReadOnlyDictionary columnMap = GetColumnIndexMap(getTablesResp.DirectResults.ResultSetMetadata.Schema.Columns); - TRowSet resp = getTablesResp.DirectResults.ResultSet.Results; + TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getSchemasResp, cancellationToken).Result; - IReadOnlyList catalogList = resp.Columns[columnMap[TableCat]].StringVal.Values; - IReadOnlyList schemaList = resp.Columns[columnMap[TableSchem]].StringVal.Values; - IReadOnlyList tableList = resp.Columns[columnMap[TableName]].StringVal.Values; - IReadOnlyList tableTypeList = resp.Columns[columnMap[TableType]].StringVal.Values; + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - for (int i = 0; i < catalogList.Count; i++) - { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - string tableName = tableList[i]; - string tableType = tableTypeList[i]; - TableInfo tableInfo = new(tableType); - catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). 
+ catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary()); + } } - } - if (depth == GetObjectsDepth.All) - { - TGetColumnsReq columnsReq = new TGetColumnsReq(this.sessionHandle); - columnsReq.CatalogName = catalogPattern; - columnsReq.SchemaName = dbSchemaPattern; - columnsReq.TableName = tableNamePattern; - columnsReq.GetDirectResults = sparkGetDirectResults; + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) + { + TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); + getTablesReq.CatalogName = catalogPattern; + getTablesReq.SchemaName = dbSchemaPattern; + getTablesReq.TableName = tableNamePattern; + getTablesReq.GetDirectResults = sparkGetDirectResults; + + TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, cancellationToken).Result; + if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getTablesResp.Status.ErrorMessage); + } - if (!string.IsNullOrEmpty(columnNamePattern)) - columnsReq.ColumnName = columnNamePattern; + TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getTablesResp, cancellationToken).Result; - var columnsResponse = this.Client.GetColumns(columnsReq).Result; - if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(columnsResponse.Status.ErrorMessage); + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string tableType = tableTypeList[i]; + TableInfo tableInfo = new(tableType); + catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + } } - IReadOnlyDictionary columnMap = GetColumnIndexMap(columnsResponse.DirectResults.ResultSetMetadata.Schema.Columns); - TRowSet resp = columnsResponse.DirectResults.ResultSet.Results; - - IReadOnlyList catalogList = resp.Columns[columnMap[TableCat]].StringVal.Values; - IReadOnlyList schemaList = resp.Columns[columnMap[TableSchem]].StringVal.Values; - IReadOnlyList tableList = resp.Columns[columnMap[TableName]].StringVal.Values; - IReadOnlyList columnNameList = resp.Columns[columnMap[ColumnName]].StringVal.Values; - ReadOnlySpan columnTypeList = resp.Columns[columnMap[DataType]].I32Val.Values.Values; - IReadOnlyList typeNameList = resp.Columns[columnMap[TypeName]].StringVal.Values; - ReadOnlySpan nullableList = resp.Columns[columnMap[Nullable]].I32Val.Values.Values; - IReadOnlyList columnDefaultList = resp.Columns[columnMap[ColumnDef]].StringVal.Values; - ReadOnlySpan ordinalPosList = resp.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values; - IReadOnlyList isNullableList = resp.Columns[columnMap[IsNullable]].StringVal.Values; - IReadOnlyList isAutoIncrementList = resp.Columns[columnMap[IsAutoIncrement]].StringVal.Values; - - for (int i = 0; i < catalogList.Count; i++) + if (depth == GetObjectsDepth.All) { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - string tableName = tableList[i]; - string columnName = columnNameList[i]; - 
short colType = (short)columnTypeList[i]; - string typeName = typeNameList[i]; - short nullable = (short)nullableList[i]; - string? isAutoIncrementString = isAutoIncrementList[i]; - bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); - string isNullable = isNullableList[i] ?? "YES"; - string columnDefault = columnDefaultList[i] ?? ""; - // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed - int ordinalPos = ordinalPosList[i] + 1; - TableInfo? tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); - tableInfo?.ColumnName.Add(columnName); - tableInfo?.ColType.Add(colType); - tableInfo?.Nullable.Add(nullable); - tableInfo?.IsAutoIncrement.Add(isAutoIncrement); - tableInfo?.IsNullable.Add(isNullable); - tableInfo?.ColumnDefault.Add(columnDefault); - tableInfo?.OrdinalPosition.Add(ordinalPos); - SetPrecisionScaleAndTypeName(colType, typeName, tableInfo); - } - } + TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); + columnsReq.CatalogName = catalogPattern; + columnsReq.SchemaName = dbSchemaPattern; + columnsReq.TableName = tableNamePattern; + columnsReq.GetDirectResults = sparkGetDirectResults; - StringArray.Builder catalogNameBuilder = new StringArray.Builder(); - List catalogDbSchemasValues = new List(); + if (!string.IsNullOrEmpty(columnNamePattern)) + columnsReq.ColumnName = columnNamePattern; - foreach (KeyValuePair>> catalogEntry in catalogMap) - { - catalogNameBuilder.Append(catalogEntry.Key); + var columnsResponse = Client.GetColumns(columnsReq, cancellationToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } - if (depth == GetObjectsDepth.Catalogs) - { - catalogDbSchemasValues.Add(null); + TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; + + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList columnNameList = rowSet.Columns[columnMap[ColumnName]].StringVal.Values; + ReadOnlySpan columnTypeList = rowSet.Columns[columnMap[DataType]].I32Val.Values.Values; + IReadOnlyList typeNameList = rowSet.Columns[columnMap[TypeName]].StringVal.Values; + ReadOnlySpan nullableList = rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values; + IReadOnlyList columnDefaultList = rowSet.Columns[columnMap[ColumnDef]].StringVal.Values; + ReadOnlySpan ordinalPosList = rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values; + IReadOnlyList isNullableList = rowSet.Columns[columnMap[IsNullable]].StringVal.Values; + IReadOnlyList isAutoIncrementList = rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + // For systems that don't support 'catalog' in the namespace + string catalog = catalogList[i] ?? 
string.Empty; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string columnName = columnNameList[i]; + short colType = (short)columnTypeList[i]; + string typeName = typeNameList[i]; + short nullable = (short)nullableList[i]; + string? isAutoIncrementString = isAutoIncrementList[i]; + bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); + string isNullable = isNullableList[i] ?? "YES"; + string columnDefault = columnDefaultList[i] ?? ""; + // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed + int ordinalPos = ordinalPosList[i] + 1; + TableInfo? tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); + tableInfo?.ColumnName.Add(columnName); + tableInfo?.ColType.Add(colType); + tableInfo?.Nullable.Add(nullable); + tableInfo?.IsAutoIncrement.Add(isAutoIncrement); + tableInfo?.IsNullable.Add(isNullable); + tableInfo?.ColumnDefault.Add(columnDefault); + tableInfo?.OrdinalPosition.Add(ordinalPos); + SetPrecisionScaleAndTypeName(colType, typeName, tableInfo); + } } - else + + StringArray.Builder catalogNameBuilder = new StringArray.Builder(); + List catalogDbSchemasValues = new List(); + + foreach (KeyValuePair>> catalogEntry in catalogMap) { - catalogDbSchemasValues.Add(GetDbSchemas( - depth, catalogEntry.Value)); + catalogNameBuilder.Append(catalogEntry.Key); + + if (depth == GetObjectsDepth.Catalogs) + { + catalogDbSchemasValues.Add(null); + } + else + { + catalogDbSchemasValues.Add(GetDbSchemas( + depth, catalogEntry.Value)); + } } - } - Schema schema = StandardSchemas.GetObjectsSchema; - IReadOnlyList dataArrays = schema.Validate( - new List - { + Schema schema = StandardSchemas.GetObjectsSchema; + IReadOnlyList dataArrays = schema.Validate( + new List + { catalogNameBuilder.Build(), catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), - }); + }); - return new SparkInfoArrowStream(schema, dataArrays); + return new SparkInfoArrowStream(schema, dataArrays); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. 
'{ex.Message}'", ex); + } } private static IReadOnlyDictionary GetColumnIndexMap(List columns) => columns @@ -708,7 +722,7 @@ private static void SetPrecisionScaleAndTypeName(short colType, string typeName, case (short)ColumnTypeId.DECIMAL: case (short)ColumnTypeId.NUMERIC: { - SqlDecimalParserResult result = new SqlDecimalTypeParser().ParseOrDefault(typeName, new SqlDecimalParserResult(typeName)); + SqlDecimalParserResult result = SqlTypeNameParser.Parse(typeName, colType); tableInfo?.Precision.Add(result.Precision); tableInfo?.Scale.Add((short)result.Scale); tableInfo?.BaseTypeName.Add(result.BaseTypeName); @@ -717,30 +731,26 @@ private static void SetPrecisionScaleAndTypeName(short colType, string typeName, case (short)ColumnTypeId.CHAR: case (short)ColumnTypeId.NCHAR: - { - bool success = new SqlCharTypeParser().TryParse(typeName, out SqlCharVarcharParserResult? result); - tableInfo?.Precision.Add(success ? result!.ColumnSize : SqlVarcharTypeParser.VarcharColumnSizeDefault); - tableInfo?.Scale.Add(null); - tableInfo?.BaseTypeName.Add(success ? result!.BaseTypeName : "CHAR"); - break; - } case (short)ColumnTypeId.VARCHAR: case (short)ColumnTypeId.LONGVARCHAR: case (short)ColumnTypeId.LONGNVARCHAR: case (short)ColumnTypeId.NVARCHAR: { - bool success = new SqlVarcharTypeParser().TryParse(typeName, out SqlCharVarcharParserResult? result); - tableInfo?.Precision.Add(success ? result!.ColumnSize : SqlVarcharTypeParser.VarcharColumnSizeDefault); + SqlCharVarcharParserResult result = SqlTypeNameParser.Parse(typeName, colType); + tableInfo?.Precision.Add(result.ColumnSize); tableInfo?.Scale.Add(null); - tableInfo?.BaseTypeName.Add(success ? result!.BaseTypeName : "STRING"); + tableInfo?.BaseTypeName.Add(result.BaseTypeName); break; } default: - tableInfo?.Precision.Add(null); - tableInfo?.Scale.Add(null); - tableInfo?.BaseTypeName.Add(typeName); - break; + { + SqlTypeNameParserResult result = SqlTypeNameParser.Parse(typeName, colType); + tableInfo?.Precision.Add(null); + tableInfo?.Scale.Add(null); + tableInfo?.BaseTypeName.Add(result.BaseTypeName); + break; + } } } @@ -783,8 +793,8 @@ private static IArrowType GetArrowType(int columnTypeId, string typeName) case (int)ColumnTypeId.NUMERIC: // Note: parsing the type name for SQL DECIMAL types as the precision and scale values // are not returned in the Thrift call to GetColumns - return new SqlDecimalTypeParser() - .ParseOrDefault(typeName, new SqlDecimalParserResult(typeName)) + return SqlTypeNameParser + .Parse(typeName, columnTypeId) .Decimal128Type; case (int)ColumnTypeId.NULL: return NullType.Default; @@ -797,7 +807,7 @@ private static IArrowType GetArrowType(int columnTypeId, string typeName) } } - private StructArray GetDbSchemas( + private static StructArray GetDbSchemas( GetObjectsDepth depth, Dictionary> schemaMap) { @@ -841,7 +851,7 @@ private StructArray GetDbSchemas( nullBitmapBuffer.Build()); } - private StructArray GetTableSchemas( + private static StructArray GetTableSchemas( GetObjectsDepth depth, Dictionary tableMap) { @@ -892,7 +902,7 @@ private StructArray GetTableSchemas( nullBitmapBuffer.Build()); } - private StructArray GetColumnSchema(TableInfo tableInfo) + private static StructArray GetColumnSchema(TableInfo tableInfo) { StringArray.Builder columnNameBuilder = new StringArray.Builder(); Int32Array.Builder ordinalPositionBuilder = new Int32Array.Builder(); @@ -976,7 +986,7 @@ private StructArray GetColumnSchema(TableInfo tableInfo) nullBitmapBuffer.Build()); } - private string PatternToRegEx(string? 
pattern) + private static string PatternToRegEx(string? pattern) { if (pattern == null) return ".*"; @@ -984,70 +994,116 @@ private string PatternToRegEx(string? pattern) StringBuilder builder = new StringBuilder("(?i)^"); string convertedPattern = pattern.Replace("_", ".").Replace("%", ".*"); builder.Append(convertedPattern); - builder.Append("$"); + builder.Append('$'); return builder.ToString(); } - private string GetProductVersion() + private static string GetProductVersion() { FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); return fileVersionInfo.ProductVersion ?? ProductVersionDefault; } - } - internal struct TableInfo(string type) - { - public string Type { get; } = type; + protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, string? port) + { + // Uri property takes precedent. + if (!string.IsNullOrWhiteSpace(uri)) + { + var uriValue = new Uri(uri); + if (uriValue.Scheme != Uri.UriSchemeHttp && uriValue.Scheme != Uri.UriSchemeHttps) + throw new ArgumentOutOfRangeException( + AdbcOptions.Uri, + uri, + $"Unsupported scheme '{uriValue.Scheme}'"); + return uriValue; + } - public List ColumnName { get; } = new(); + bool isPortSet = !string.IsNullOrEmpty(port); + bool isValidPortNumber = int.TryParse(port, out int portNumber) && portNumber > 0; + bool isDefaultHttpsPort = !isPortSet || (isValidPortNumber && portNumber == 443); + string uriScheme = isDefaultHttpsPort ? Uri.UriSchemeHttps : Uri.UriSchemeHttp; + int uriPort; + if (!isPortSet) + uriPort = -1; + else if (isValidPortNumber) + uriPort = portNumber; + else + throw new ArgumentOutOfRangeException(nameof(port), portNumber, $"Port number is not in a valid range."); - public List ColType { get; } = new(); + Uri baseAddress = new UriBuilder(uriScheme, hostName, uriPort, path).Uri; + return baseAddress; + } - public List BaseTypeName { get; } = new(); + protected abstract void ValidateConnection(); + protected abstract void ValidateAuthentication(); + protected abstract void ValidateOptions(); - public List TypeName { get; } = new(); + protected abstract Task GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default); + protected abstract Task GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + protected abstract Task GetRowSetAsync(TGetCatalogsResp getCatalogsResp, CancellationToken cancellationToken = default); + protected abstract Task GetRowSetAsync(TGetSchemasResp getSchemasResp, CancellationToken cancellationToken = default); + protected abstract Task GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default); + protected abstract Task GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default); + protected abstract Task GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default); - public List Nullable { get; } = new(); + internal abstract SparkServerType ServerType { get; } - public List Precision { get; } = new(); + internal struct TableInfo(string type) + { + public string Type { get; } = type; - public List Scale { get; } = new(); + public List ColumnName { get; } = new(); - public List OrdinalPosition { get; } = 
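// Illustrative sketch (not part of the patch): what the PatternToRegEx conversion above
// produces for typical GetObjects search patterns. The pattern and name values are placeholders.
using System;
using System.Text.RegularExpressions;

class PatternSketch
{
    static void Main()
    {
        // SQL LIKE wildcards: '%' matches any run of characters, '_' matches a single character.
        // PatternToRegEx maps them to ".*" and "." and anchors the match case-insensitively.
        Console.WriteLine(Regex.IsMatch("main_catalog", "(?i)^main.*$")); // True  (pattern "main%")
        Console.WriteLine(Regex.IsMatch("DEFAULT", "(?i)^default$"));     // True  (pattern "default", case-insensitive)
        Console.WriteLine(Regex.IsMatch("sales2024", "(?i)^sales.$"));    // False (pattern "sales_" allows only one trailing character)
    }
}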
new(); + public List ColType { get; } = new(); - public List ColumnDefault { get; } = new(); + public List BaseTypeName { get; } = new(); - public List IsNullable { get; } = new(); + public List TypeName { get; } = new(); - public List IsAutoIncrement { get; } = new(); - } + public List Nullable { get; } = new(); - internal class SparkInfoArrowStream : IArrowArrayStream - { - private Schema schema; - private RecordBatch? batch; + public List Precision { get; } = new(); - public SparkInfoArrowStream(Schema schema, IReadOnlyList data) - { - this.schema = schema; - this.batch = new RecordBatch(schema, data, data[0].Length); - } + public List Scale { get; } = new(); - public Schema Schema { get { return this.schema; } } + public List OrdinalPosition { get; } = new(); - public ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) - { - RecordBatch? batch = this.batch; - this.batch = null; - return new ValueTask(batch); + public List ColumnDefault { get; } = new(); + + public List IsNullable { get; } = new(); + + public List IsAutoIncrement { get; } = new(); } - public void Dispose() + internal class SparkInfoArrowStream : IArrowArrayStream { - this.batch?.Dispose(); - this.batch = null; + private Schema schema; + private RecordBatch? batch; + + public SparkInfoArrowStream(Schema schema, IReadOnlyList data) + { + this.schema = schema; + this.batch = new RecordBatch(schema, data, data[0].Length); + } + + public Schema Schema { get { return this.schema; } } + + public ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) + { + RecordBatch? batch = this.batch; + this.batch = null; + return new ValueTask(batch); + } + + public void Dispose() + { + this.batch?.Dispose(); + this.batch = null; + } } } } diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs b/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs new file mode 100644 index 0000000000..7e432289e0 --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs @@ -0,0 +1,41 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal class SparkConnectionFactory + { + public static SparkConnection NewConnection(IReadOnlyDictionary properties) + { + bool _ = properties.TryGetValue(SparkParameters.Type, out string? 
type) && string.IsNullOrEmpty(type); + bool __ = ServerTypeParser.TryParse(type, out SparkServerType serverTypeValue); + return serverTypeValue switch + { + SparkServerType.Databricks => new SparkDatabricksConnection(properties), + SparkServerType.Http => new SparkHttpConnection(properties), + // TODO: Re-enable when properly supported + //SparkServerType.Standard => new SparkStandardConnection(properties), + SparkServerType.Empty => throw new ArgumentException($"Required property '{SparkParameters.Type}' is missing. Supported types: {ServerTypeParser.SupportedList}", nameof(properties)), + _ => throw new ArgumentOutOfRangeException(nameof(properties), $"Unsupported or unknown value '{type}' given for property '{SparkParameters.Type}'. Supported types: {ServerTypeParser.SupportedList}"), + }; + } + + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs index 92687e40d9..02fea3f5ef 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs @@ -15,7 +15,9 @@ * limitations under the License. */ +using System; using System.Collections.Generic; +using System.Linq; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { @@ -28,9 +30,15 @@ public SparkDatabase(IReadOnlyDictionary properties) this.properties = properties; } - public override AdbcConnection Connect(IReadOnlyDictionary? properties) + public override AdbcConnection Connect(IReadOnlyDictionary? options) { - SparkConnection connection = new SparkConnection(this.properties); + // connection options takes precedence over database properties for the same option + IReadOnlyDictionary mergedProperties = options == null + ? properties + : options + .Concat(properties.Where(x => !options.Keys.Contains(x.Key, StringComparer.OrdinalIgnoreCase))) + .ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + SparkConnection connection = SparkConnectionFactory.NewConnection(mergedProperties); // new SparkConnection(mergedProperties); connection.OpenAsync().Wait(); return connection; } diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs new file mode 100644 index 0000000000..d51ef42b9b --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs @@ -0,0 +1,67 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
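// Illustrative sketch (not part of the patch): how a caller could reach the factory above.
// Assumes a context where the driver's internal types are visible (for example a test project
// with InternalsVisibleTo); host, path and token values are placeholders.
using System.Collections.Generic;
using Apache.Arrow.Adbc;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;

var databaseProperties = new Dictionary<string, string>
{
    [SparkParameters.Type] = SparkServerTypeConstants.Databricks,  // selects SparkDatabricksConnection
    [SparkParameters.HostName] = "example.cloud.databricks.com",
    [SparkParameters.Path] = "/sql/1.0/warehouses/abc123",
    [SparkParameters.AuthType] = SparkAuthTypeConstants.Token,
    [SparkParameters.Token] = "dapi-placeholder",
};

// Per-connection options win over database properties for the same key,
// so this connection targets a different warehouse path than the database default.
var connectionOptions = new Dictionary<string, string>
{
    [SparkParameters.Path] = "/sql/1.0/warehouses/def456",
};

AdbcConnection connection = new SparkDatabase(databaseProperties).Connect(connectionOptions);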
+*/ + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal class SparkDatabricksConnection : SparkHttpConnection + { + public SparkDatabricksConnection(IReadOnlyDictionary properties) : base(properties) + { + } + + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new SparkDatabricksReader(statement, schema); + + internal override SchemaParser SchemaParser => new SparkDatabricksSchemaParser(); + + internal override SparkServerType ServerType => SparkServerType.Databricks; + + protected override TOpenSessionReq CreateSessionRequest() + { + var req = new TOpenSessionReq(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7) + { + CanUseMultipleCatalogs = true, + }; + return req; + } + + protected override Task GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSetMetadata); + protected override Task GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSetMetadata); + protected override Task GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSetMetadata); + protected override Task GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSetMetadata); + + protected override Task GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSet.Results); + protected override Task GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSet.Results); + protected override Task GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSet.Results); + protected override Task GetRowSetAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSet.Results); + protected override Task GetRowSetAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => + Task.FromResult(response.DirectResults.ResultSet.Results); + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs new file mode 100644 index 0000000000..059ab1690b --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
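// Illustrative sketch (not part of the patch): why the Databricks overrides above can simply
// return response.DirectResults. When a metadata request is sent with GetDirectResults populated,
// the server piggybacks the result-set metadata and the first batch of rows onto the original
// response, so no follow-up GetResultSetMetadata/FetchResults round-trips are needed
// (the plain HTTP connection later in this patch polls and fetches instead).
using System.Threading;
using System.Threading.Tasks;
using Apache.Hive.Service.Rpc.Thrift;

internal static class DirectResultsSketch
{
    internal static async Task<(TGetResultSetMetadataResp Metadata, TRowSet Rows)> GetTablesDirectAsync(
        TCLIService.IAsync client,
        TGetTablesReq request,               // request.GetDirectResults is assumed to be set by the caller
        CancellationToken cancellationToken = default)
    {
        TGetTablesResp response = await client.GetTables(request, cancellationToken);
        return (response.DirectResults.ResultSetMetadata, response.DirectResults.ResultSet.Results);
    }
}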
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal sealed class SparkDatabricksReader : IArrowArrayStream + { + HiveServer2Statement? statement; + Schema schema; + List? batches; + int index; + IArrowReader? reader; + + public SparkDatabricksReader(HiveServer2Statement statement, Schema schema) + { + this.statement = statement; + this.schema = schema; + } + + public Schema Schema { get { return schema; } } + + public async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) + { + while (true) + { + if (this.reader != null) + { + RecordBatch? next = await this.reader.ReadNextRecordBatchAsync(cancellationToken); + if (next != null) + { + return next; + } + this.reader = null; + } + + if (this.batches != null && this.index < this.batches.Count) + { + this.reader = new ArrowStreamReader(new ChunkStream(this.schema, this.batches[this.index++].Batch)); + continue; + } + + this.batches = null; + this.index = 0; + + if (this.statement == null) + { + return null; + } + + TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize); + TFetchResultsResp response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken); + this.batches = response.Results.ArrowBatches; + + if (!response.HasMoreRows) + { + this.statement = null; + } + } + } + + public void Dispose() + { + } + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksSchemaParser.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksSchemaParser.cs new file mode 100644 index 0000000000..995c1edf09 --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksSchemaParser.cs @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
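// Illustrative sketch (not part of the patch): draining an IArrowArrayStream such as the reader
// above. ReadNextRecordBatchAsync keeps yielding batches until the buffered Arrow batches are
// exhausted and the server reports HasMoreRows == false, at which point it returns null.
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow;
using Apache.Arrow.Ipc;

internal static class ReaderSketch
{
    internal static async Task<long> CountRowsAsync(IArrowArrayStream stream, CancellationToken cancellationToken = default)
    {
        long totalRows = 0;
        while (true)
        {
            using RecordBatch? batch = await stream.ReadNextRecordBatchAsync(cancellationToken);
            if (batch == null) break;   // end of stream
            totalRows += batch.Length;  // RecordBatch.Length is the row count of the batch
        }
        return totalRows;
    }
}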
+*/ + +using System; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Types; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal class SparkDatabricksSchemaParser : SchemaParser + { + public override IArrowType GetArrowType(TPrimitiveTypeEntry thriftType, DataTypeConversion dataTypeConversion) + { + return thriftType.Type switch + { + TTypeId.BIGINT_TYPE => Int64Type.Default, + TTypeId.BINARY_TYPE => BinaryType.Default, + TTypeId.BOOLEAN_TYPE => BooleanType.Default, + TTypeId.DATE_TYPE => Date32Type.Default, + TTypeId.DOUBLE_TYPE => DoubleType.Default, + TTypeId.FLOAT_TYPE => FloatType.Default, + TTypeId.INT_TYPE => Int32Type.Default, + TTypeId.NULL_TYPE => NullType.Default, + TTypeId.SMALLINT_TYPE => Int16Type.Default, + TTypeId.TIMESTAMP_TYPE => new TimestampType(TimeUnit.Microsecond, (string?)null), + TTypeId.TINYINT_TYPE => Int8Type.Default, + TTypeId.DECIMAL_TYPE => NewDecima128Type(thriftType), + TTypeId.CHAR_TYPE + or TTypeId.STRING_TYPE + or TTypeId.VARCHAR_TYPE + or TTypeId.INTERVAL_DAY_TIME_TYPE + or TTypeId.INTERVAL_YEAR_MONTH_TYPE + or TTypeId.ARRAY_TYPE + or TTypeId.MAP_TYPE + or TTypeId.STRUCT_TYPE + or TTypeId.UNION_TYPE + or TTypeId.USER_DEFINED_TYPE => StringType.Default, + TTypeId.TIMESTAMPLOCALTZ_TYPE => throw new NotImplementedException(), + _ => throw new NotImplementedException(), + }; + } + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs new file mode 100644 index 0000000000..4c068aaa57 --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs @@ -0,0 +1,271 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Security; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; +using Thrift; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal class SparkHttpConnection : SparkConnection + { + private const string BasicAuthenticationScheme = "Basic"; + private const string BearerAuthenticationScheme = "Bearer"; + + public SparkHttpConnection(IReadOnlyDictionary properties) : base(properties) + { + } + + protected override void ValidateAuthentication() + { + // Validate authentication parameters + Properties.TryGetValue(SparkParameters.Token, out string? token); + Properties.TryGetValue(AdbcOptions.Username, out string? 
username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(SparkParameters.AuthType, out string? authType); + bool isValidAuthType = AuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); + switch (authTypeValue) + { + case SparkAuthType.Token: + if (string.IsNullOrWhiteSpace(token)) + throw new ArgumentException( + $"Parameter '{SparkParameters.AuthType}' is set to '{SparkAuthTypeConstants.Token}' but parameter '{SparkParameters.Token}' is not set. Please provide a value for '{SparkParameters.Token}'.", + nameof(Properties)); + break; + case SparkAuthType.Basic: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameter '{SparkParameters.AuthType}' is set to '{SparkAuthTypeConstants.Basic}' but parameters '{AdbcOptions.Username}' or '{AdbcOptions.Password}' are not set. Please provide values for these parameters.", + nameof(Properties)); + break; + case SparkAuthType.UsernameOnly: + if (string.IsNullOrWhiteSpace(username)) + throw new ArgumentException( + $"Parameter '{SparkParameters.AuthType}' is set to '{SparkAuthTypeConstants.UsernameOnly}' but parameter '{AdbcOptions.Username}' is not set. Please provide a value for this parameter.", + nameof(Properties)); + break; + case SparkAuthType.None: + break; + case SparkAuthType.Empty: + if (string.IsNullOrWhiteSpace(token) && (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password))) + throw new ArgumentException( + $"Parameters must include valid authentication settings. Please provide either '{SparkParameters.Token}'; or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'.", + nameof(Properties)); + break; + default: + throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); + } + } + + protected override void ValidateConnection() + { + // HostName or Uri is a required parameter + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + Properties.TryGetValue(SparkParameters.HostName, out string? hostName); + if ((Uri.CheckHostName(hostName) == UriHostNameType.Unknown) + && (string.IsNullOrEmpty(uri) || !Uri.TryCreate(uri, UriKind.Absolute, out Uri? _))) + { + throw new ArgumentException( + $"Required parameter '{SparkParameters.HostName}' or '{AdbcOptions.Uri}' is missing or invalid. Please provide a valid hostname or URI for the data source.", + nameof(Properties)); + } + + // Validate port range + Properties.TryGetValue(SparkParameters.Port, out string? port); + if (int.TryParse(port, out int portNumber) && (portNumber <= IPEndPoint.MinPort || portNumber > IPEndPoint.MaxPort)) + throw new ArgumentOutOfRangeException( + nameof(Properties), + port, + $"Parameter '{SparkParameters.Port}' value is not in the valid range of 1 .. {IPEndPoint.MaxPort}."); + + // Ensure the parameters will produce a valid address + Properties.TryGetValue(SparkParameters.Path, out string? path); + _ = new HttpClient() + { + BaseAddress = GetBaseAddress(uri, hostName, path, port) + }; + } + + protected override void ValidateOptions() + { + Properties.TryGetValue(SparkParameters.DataTypeConv, out string? dataTypeConv); + DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv); + Properties.TryGetValue(SparkParameters.TLSOptions, out string? tlsOptions); + TlsOptions = TlsOptionsParser.Parse(tlsOptions); + Properties.TryGetValue(SparkParameters.ConnectTimeoutMilliseconds, out string?
connectTimeoutMs); + if (connectTimeoutMs != null) + { + ConnectTimeoutMilliseconds = int.TryParse(connectTimeoutMs, NumberStyles.Integer, CultureInfo.InvariantCulture, out int connectTimeoutMsValue) && (connectTimeoutMsValue >= 0) + ? connectTimeoutMsValue + : throw new ArgumentOutOfRangeException(SparkParameters.ConnectTimeoutMilliseconds, connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 .. {int.MaxValue}. default is 30000 milliseconds."); + } + } + + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + + protected override TTransport CreateTransport() + { + // Assumption: parameters have already been validated. + Properties.TryGetValue(SparkParameters.HostName, out string? hostName); + Properties.TryGetValue(SparkParameters.Path, out string? path); + Properties.TryGetValue(SparkParameters.Port, out string? port); + Properties.TryGetValue(SparkParameters.AuthType, out string? authType); + bool isValidAuthType = AuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); + Properties.TryGetValue(SparkParameters.Token, out string? token); + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + + Uri baseAddress = GetBaseAddress(uri, hostName, path, port); + AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password); + + HttpClientHandler httpClientHandler = NewHttpClientHandler(); + HttpClient httpClient = new(httpClientHandler); + httpClient.BaseAddress = baseAddress; + httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); + httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); + httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); + httpClient.DefaultRequestHeaders.ExpectContinue = false; + + TConfiguration config = new(); + ThriftHttpTransport transport = new(httpClient, config) + { + // This value can only be set before the first call/request. So if a new value for query timeout + // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token + // to ensure cancelled correctly. + ConnectTimeout = int.MaxValue, + }; + return transport; + } + + private HttpClientHandler NewHttpClientHandler() + { + HttpClientHandler httpClientHandler = new(); + if (TlsOptions != HiveServer2TlsOption.Empty) + { + httpClientHandler.ServerCertificateCustomValidationCallback = (request, certificate, chain, policyErrors) => + { + if (policyErrors == SslPolicyErrors.None) return true; + + return + (!policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors) || TlsOptions.HasFlag(HiveServer2TlsOption.AllowSelfSigned)) + && (!policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch) || TlsOptions.HasFlag(HiveServer2TlsOption.AllowHostnameMismatch)); + }; + } + + return httpClientHandler; + } + + private static AuthenticationHeaderValue? GetAuthenticationHeaderValue(SparkAuthType authType, string? token, string? username, string? 
password) + { + if (!string.IsNullOrEmpty(token) && (authType == SparkAuthType.Empty || authType == SparkAuthType.Token)) + { + return new AuthenticationHeaderValue(BearerAuthenticationScheme, token); + } + else if (!string.IsNullOrEmpty(username) && !string.IsNullOrEmpty(password) && (authType == SparkAuthType.Empty || authType == SparkAuthType.Basic)) + { + return new AuthenticationHeaderValue(BasicAuthenticationScheme, Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:{password}"))); + } + else if (!string.IsNullOrEmpty(username) && (authType == SparkAuthType.Empty || authType == SparkAuthType.UsernameOnly)) + { + return new AuthenticationHeaderValue(BasicAuthenticationScheme, Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:"))); + } + else if (authType == SparkAuthType.None) + { + return null; + } + else + { + throw new AdbcException("Missing connection properties. Must contain 'token' or 'username' and 'password'"); + } + } + + protected override async Task CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default) + { + if (!transport.IsOpen) await transport.OpenAsync(cancellationToken); + return new TBinaryProtocol(transport); + } + + protected override TOpenSessionReq CreateSessionRequest() + { + var req = new TOpenSessionReq(TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11) + { + CanUseMultipleCatalogs = true, + }; + return req; + } + + protected override Task GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => + GetResultSetMetadataAsync(response.OperationHandle, Client, cancellationToken); + protected override Task GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => + GetResultSetMetadataAsync(response.OperationHandle, Client, cancellationToken); + protected override Task GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => + GetResultSetMetadataAsync(response.OperationHandle, Client, cancellationToken); + protected override Task GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => + GetResultSetMetadataAsync(response.OperationHandle, Client, cancellationToken); + protected override Task GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + protected override Task GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + protected override Task GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + protected override Task GetRowSetAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + protected override Task GetRowSetAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => + FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); + + private async Task FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default) + { + await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault, cancellationToken); + + TFetchResultsResp 
fetchResp = await FetchNextAsync(operationHandle, Client, batchSize, cancellationToken); + if (fetchResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(fetchResp.Status.ErrorMessage) + .SetNativeError(fetchResp.Status.ErrorCode) + .SetSqlState(fetchResp.Status.SqlState); + } + return fetchResp.Results; + } + + internal static async Task FetchNextAsync(TOperationHandle operationHandle, TCLIService.IAsync client, long batchSize, CancellationToken cancellationToken = default) + { + TFetchResultsReq request = new(operationHandle, TFetchOrientation.FETCH_NEXT, batchSize); + TFetchResultsResp response = await client.FetchResults(request, cancellationToken); + return response; + } + + internal override SchemaParser SchemaParser => new HiveServer2SchemaParser(); + + internal override SparkServerType ServerType => SparkServerType.Http; + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs index 4843d13b4d..6cb96dd5f1 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs @@ -20,12 +20,31 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark /// /// Parameters used for connecting to Spark data sources. /// - public class SparkParameters + public static class SparkParameters { public const string HostName = "adbc.spark.host"; public const string Port = "adbc.spark.port"; public const string Path = "adbc.spark.path"; public const string Token = "adbc.spark.token"; - public const string Password = "password"; + public const string AuthType = "adbc.spark.auth_type"; + public const string Type = "adbc.spark.type"; + public const string DataTypeConv = "adbc.spark.data_type_conv"; + public const string TLSOptions = "adbc.spark.tls_options"; + public const string ConnectTimeoutMilliseconds = "adbc.spark.connect_timeout_ms"; + } + + public static class SparkAuthTypeConstants + { + public const string None = "none"; + public const string UsernameOnly = "username_only"; + public const string Basic = "basic"; + public const string Token = "token"; + } + + public static class SparkServerTypeConstants + { + public const string Http = "http"; + public const string Databricks = "databricks"; + public const string Standard = "standard"; } } diff --git a/csharp/src/Drivers/Apache/Spark/SparkServerType.cs b/csharp/src/Drivers/Apache/Spark/SparkServerType.cs new file mode 100644 index 0000000000..351a2a0b9d --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkServerType.cs @@ -0,0 +1,56 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
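// Illustrative sketch (not part of the patch): the new option keys applied to a plain HTTP
// (non-Databricks) endpoint with basic authentication. Host, path and credential values are
// placeholders; 'adbc.spark.connect_timeout_ms' accepts 0 (infinite) or a positive integer.
using System.Collections.Generic;
using Apache.Arrow.Adbc;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;

var properties = new Dictionary<string, string>
{
    [SparkParameters.Type] = SparkServerTypeConstants.Http,
    [SparkParameters.HostName] = "spark-thrift.example.internal",
    [SparkParameters.Port] = "10001",             // a non-443 port resolves to the http scheme
    [SparkParameters.Path] = "/cliservice",
    [SparkParameters.AuthType] = SparkAuthTypeConstants.Basic,
    [AdbcOptions.Username] = "someuser",
    [AdbcOptions.Password] = "somepassword",
    [SparkParameters.ConnectTimeoutMilliseconds] = "30000",
};
// These properties would then be handed to SparkDatabase / SparkConnectionFactory as shown earlier.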
+*/ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal enum SparkServerType + { + Invalid = 0, + Http, + Databricks, + Standard, + Empty = int.MaxValue, + } + + internal static class ServerTypeParser + { + internal const string SupportedList = SparkServerTypeConstants.Http + ", " + SparkServerTypeConstants.Databricks; + + internal static bool TryParse(string? serverType, out SparkServerType serverTypeValue) + { + switch (serverType?.Trim().ToLowerInvariant()) + { + case null: + case "": + serverTypeValue = SparkServerType.Empty; + return true; + case SparkServerTypeConstants.Databricks: + serverTypeValue = SparkServerType.Databricks; + return true; + case SparkServerTypeConstants.Http: + serverTypeValue = SparkServerType.Http; + return true; + case SparkServerTypeConstants.Standard: + serverTypeValue = SparkServerType.Standard; + return true; + default: + serverTypeValue = SparkServerType.Invalid; + return false; + } + } + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs new file mode 100644 index 0000000000..c8ab5772c9 --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs @@ -0,0 +1,137 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal class SparkStandardConnection : SparkHttpConnection + { + public SparkStandardConnection(IReadOnlyDictionary properties) : base(properties) + { + } + + protected override void ValidateAuthentication() + { + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(SparkParameters.AuthType, out string? authType); + bool isValidAuthType = AuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); + switch (authTypeValue) + { + case SparkAuthType.None: + break; + case SparkAuthType.UsernameOnly: + if (string.IsNullOrWhiteSpace(username)) + throw new ArgumentException( + $"Parameter '{SparkParameters.AuthType}' is set to '{SparkAuthTypeConstants.UsernameOnly}' but parameters '{AdbcOptions.Username}' is not set. Please provide a value for this parameter.", + nameof(Properties)); + break; + case SparkAuthType.Basic: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameter '{SparkParameters.AuthType}' is set to '{SparkAuthTypeConstants.Basic}' but parameters '{AdbcOptions.Username}' or '{AdbcOptions.Password}' are not set. 
Please provide values for these parameters.", + nameof(Properties)); + break; + case SparkAuthType.Empty: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameters must include valid authentication settings. Please provide either '{SparkParameters.Token}'; or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'.", + nameof(Properties)); + break; + default: + throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); + } + } + + protected override void ValidateConnection() + { + // HostName is a required parameter + Properties.TryGetValue(SparkParameters.HostName, out string? hostName); + if (Uri.CheckHostName(hostName) == UriHostNameType.Unknown) + { + throw new ArgumentException( + $"Required parameter '{SparkParameters.HostName}' is missing or invalid. Please provide a valid hostname for the data source.", + nameof(Properties)); + } + + // Validate port range + Properties.TryGetValue(SparkParameters.Port, out string? port); + if (int.TryParse(port, out int portNumber) && (portNumber <= IPEndPoint.MinPort || portNumber > IPEndPoint.MaxPort)) + throw new ArgumentOutOfRangeException( + nameof(Properties), + port, + $"Parameter '{SparkParameters.Port}' value is not in the valid range of 1 .. {IPEndPoint.MaxPort}."); + + } + + protected override TTransport CreateTransport() + { + // Assumption: hostName and port have already been validated. + Properties.TryGetValue(SparkParameters.HostName, out string? hostName); + Properties.TryGetValue(SparkParameters.Port, out string? port); + + // Delay the open connection until later. + bool connectClient = false; + ThriftSocketTransport transport = new(hostName!, int.Parse(port!), connectClient, config: new()); + return transport; + } + + protected override async Task CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default) + { + return await base.CreateProtocolAsync(transport, cancellationToken); + + //if (!transport.IsOpen) await transport.OpenAsync(CancellationToken.None); + //return new TBinaryProtocol(transport); + } + + protected override TOpenSessionReq CreateSessionRequest() + { + // Assumption: user name and password have already been validated. + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(SparkParameters.AuthType, out string? authType); + bool isValidAuthType = AuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); + TOpenSessionReq request = base.CreateSessionRequest(); + switch (authTypeValue) + { + case SparkAuthType.UsernameOnly: + case SparkAuthType.Basic: + case SparkAuthType.Empty when !string.IsNullOrEmpty(username): + request.Username = username!; + break; + } + switch (authTypeValue) + { + case SparkAuthType.Basic: + case SparkAuthType.Empty when !string.IsNullOrEmpty(password): + request.Password = password!; + break; + } + return request; + } + + internal override SparkServerType ServerType => SparkServerType.Standard; + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs index 4ab67bdcac..25888b1a3b 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs @@ -15,22 +15,30 @@ * limitations under the License.
*/ -using System; using System.Collections.Generic; -using System.Net.Http; -using System.Threading; -using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; -using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { - public class SparkStatement : HiveServer2Statement + internal class SparkStatement : HiveServer2Statement { internal SparkStatement(SparkConnection connection) : base(connection) { + foreach (KeyValuePair kvp in connection.Properties) + { + switch (kvp.Key) + { + case Options.BatchSize: + case Options.PollTimeMilliseconds: + case Options.QueryTimeoutSeconds: + { + SetOption(kvp.Key, kvp.Value); + break; + } + } + } } protected override void SetStatementProperties(TExecuteStatementReq statement) @@ -38,7 +46,9 @@ protected override void SetStatementProperties(TExecuteStatementReq statement) // TODO: Ensure this is set dynamically depending on server capabilities. statement.EnforceResultPersistenceMode = false; statement.ResultPersistenceMode = 2; - + // This seems like a good idea to have the server timeout so it doesn't keep processing unnecessarily. + // Set in combination with a CancellationToken. + statement.QueryTimeout = QueryTimeoutSeconds; statement.CanReadArrowResult = true; statement.CanDownloadResult = true; statement.ConfOverlay = SparkConnection.timestampConfig; @@ -55,75 +65,12 @@ protected override void SetStatementProperties(TExecuteStatementReq statement) }; } - protected override IArrowArrayStream NewReader(T statement, Schema schema) => new SparkReader(statement, schema); - /// /// Provides the constant string key values to the method. /// - public new sealed class Options : HiveServer2Statement.Options + public sealed class Options : ApacheParameters { // options specific to Spark go here } - - sealed class SparkReader : IArrowArrayStream - { - HiveServer2Statement? statement; - Schema schema; - List? batches; - int index; - IArrowReader? reader; - - public SparkReader(HiveServer2Statement statement, Schema schema) - { - this.statement = statement; - this.schema = schema; - } - - public Schema Schema { get { return schema; } } - - public async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) - { - while (true) - { - if (this.reader != null) - { - RecordBatch? 
next = await this.reader.ReadNextRecordBatchAsync(cancellationToken); - if (next != null) - { - return next; - } - this.reader = null; - } - - if (this.batches != null && this.index < this.batches.Count) - { - this.reader = new ArrowStreamReader(new ChunkStream(this.schema, this.batches[this.index++].Batch)); - continue; - } - - this.batches = null; - this.index = 0; - - if (this.statement == null) - { - return null; - } - - TFetchResultsReq request = new TFetchResultsReq(this.statement.operationHandle, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize); - TFetchResultsResp response = await this.statement.connection.client!.FetchResults(request, cancellationToken); - this.batches = response.Results.ArrowBatches; - - if (!response.HasMoreRows) - { - this.statement = null; - } - } - } - - public void Dispose() - { - } - } - } } diff --git a/csharp/src/Drivers/Apache/Spark/SqlTypeNameParser.cs b/csharp/src/Drivers/Apache/Spark/SqlTypeNameParser.cs index 915fb1a191..ab56703eb1 100644 --- a/csharp/src/Drivers/Apache/Spark/SqlTypeNameParser.cs +++ b/csharp/src/Drivers/Apache/Spark/SqlTypeNameParser.cs @@ -16,29 +16,133 @@ */ using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; using System.Text.RegularExpressions; using Apache.Arrow.Types; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { + /// + /// Interface for the SQL type name parser. + /// + internal interface ISqlTypeNameParser + { + /// + /// Tries to parse the input string for a valid SQL type definition. + /// + /// The SQL type defintion string to parse. + /// If successful, the result; otherwise null. + /// True if it can successfully parse the type definition input string; otherwise false. + bool TryParse(string input, out SqlTypeNameParserResult? result); + } + /// /// Abstract and generic SQL data type name parser. 
/// - /// The type when returning a successful parse - internal abstract class SqlTypeNameParser where T : ParserResult + /// The type when returning a successful parse + internal abstract class SqlTypeNameParser : ISqlTypeNameParser where T : SqlTypeNameParserResult { + private static readonly ConcurrentDictionary s_cache = new(); + + private static readonly IReadOnlyDictionary s_parserMap = new Dictionary() + { + { (int)SparkConnection.ColumnTypeId.ARRAY, SqlArrayTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.BIGINT, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BIGINT.ToString()) }, + { (int)SparkConnection.ColumnTypeId.BIT, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BIT.ToString()) }, + { (int)SparkConnection.ColumnTypeId.BINARY, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BINARY.ToString()) }, + { (int)SparkConnection.ColumnTypeId.BLOB, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BLOB.ToString()) }, + { (int)SparkConnection.ColumnTypeId.BOOLEAN, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BOOLEAN.ToString()) }, + { (int)SparkConnection.ColumnTypeId.CHAR, SqlCharTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.CLOB, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.CLOB.ToString()) }, + { (int)SparkConnection.ColumnTypeId.DATALINK, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.DATALINK.ToString()) }, + { (int)SparkConnection.ColumnTypeId.DATE, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.DATE.ToString()) }, + { (int)SparkConnection.ColumnTypeId.DECIMAL, SqlDecimalTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.DISTINCT, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.DISTINCT.ToString()) }, + { (int)SparkConnection.ColumnTypeId.DOUBLE, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.DOUBLE.ToString()) }, + { (int)SparkConnection.ColumnTypeId.FLOAT, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.FLOAT.ToString()) }, + { (int)SparkConnection.ColumnTypeId.INTEGER, SqlIntegerTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.JAVA_OBJECT, SqlMapTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.LONGNVARCHAR, SqlVarcharTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.LONGVARCHAR, SqlVarcharTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.NCHAR, SqlCharTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.NCLOB, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.NCLOB.ToString()) }, + { (int)SparkConnection.ColumnTypeId.NULL, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.NULL.ToString()) }, + { (int)SparkConnection.ColumnTypeId.NUMERIC, SqlDecimalTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.NVARCHAR, SqlVarcharTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.OTHER, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.OTHER.ToString()) }, + { (int)SparkConnection.ColumnTypeId.REAL, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.REAL.ToString()) }, + { (int)SparkConnection.ColumnTypeId.REF_CURSOR, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.REF_CURSOR.ToString()) }, + { (int)SparkConnection.ColumnTypeId.REF, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.REF.ToString()) }, + { (int)SparkConnection.ColumnTypeId.ROWID, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.ROWID.ToString()) }, + { (int)SparkConnection.ColumnTypeId.SMALLINT, 
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.SMALLINT.ToString()) }, + { (int)SparkConnection.ColumnTypeId.STRUCT, SqlStructTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.TIME, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.TIME.ToString()) }, + { (int)SparkConnection.ColumnTypeId.TIME_WITH_TIMEZONE, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.TIME_WITH_TIMEZONE.ToString()) }, + { (int)SparkConnection.ColumnTypeId.TIMESTAMP, SqlTimestampTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.TIMESTAMP_WITH_TIMEZONE, SqlTimestampTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.TINYINT, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.TINYINT.ToString()) }, + { (int)SparkConnection.ColumnTypeId.VARCHAR, SqlVarcharTypeParser.Default }, + { (int)SparkConnection.ColumnTypeId.SQLXML, SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.SQLXML.ToString()) }, + }; + + // Note: the INTERVAL sql type does not have an associated column type id. + private static readonly HashSet s_parsers = s_parserMap.Values + .Concat([SqlIntervalTypeParser.Default, SqlSimpleTypeParser.Default("VOID")]) + .ToHashSet(); + + /// + /// Gets the base SQL type name without decoration or sub clauses + /// + public abstract string BaseTypeName { get; } + + /// + /// Parses the input type name string and produces a result. + /// When a matching parser is found that successfully parses the type name string, the result of that parse is returned. + /// If no parser is able to successfully match the input type name, + /// then a is thrown. + /// + /// The type name string to parse + /// If provided, the column type id is used as a hint to find the most likely matching parser. + /// + /// A parser result, from a successful match and parse. + /// + public static T Parse(string input, int? columnTypeIdHint = null) => + SqlTypeNameParser.TryParse(input, out SqlTypeNameParserResult? result, columnTypeIdHint) && result != null + ? CastResultOrThrow(input, result) + : throw new NotSupportedException($"Unsupported SQL type name: '{input}'"); + /// /// Gets the expression to parse the SQL type name /// - public abstract Regex Expression { get; } + protected abstract Regex Expression { get; } /// - /// Generates the successful result of a matching parse + /// Generates the successful result for a matching parse /// /// The original SQL type name /// The successful result /// - public abstract T GenerateResult(string input, Match match); + protected virtual T GenerateResult(string input, Match match) => (T)new SqlTypeNameParserResult(input, BaseTypeName); + + private static T CastResultOrThrow(string input, SqlTypeNameParserResult result) => + (result is T typedResult) + ? typedResult + : throw new InvalidCastException($"Cannot cast return type '{result.GetType().Name}' to type '{(typeof(T)).Name}' for input SQL type name: '{input}'."); + + /// + /// Tries to parse the input string for a valid SQL type definition. + /// + /// The SQL type defintion string to parse. + /// If successful, the result; otherwise null. + /// True if it can successfully parse the type definition input string; otherwise false. + bool ISqlTypeNameParser.TryParse(string input, out SqlTypeNameParserResult? result) + { + bool success = TryParse(input, out T? typedResult); + result = success ? typedResult : (SqlTypeNameParserResult?)default; + return success; + } /// /// Tries to parse the input string for a valid SQL type definition. 
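Editor's note: the sketch below is not part of the diff. It illustrates how the generic `Parse` entry point defined above is expected to be used; the column type id (when available from metadata) acts as a hint so the most likely parser is tried first before the full parser set is scanned. The exact call sites in the driver are assumed.

```csharp
// Illustrative only (not part of this PR); assumes the SqlTypeNameParser<T>.Parse
// signature shown above and the SparkConnection.ColumnTypeId values used in the parser map.
SqlDecimalParserResult decimalResult = SqlTypeNameParser<SqlDecimalParserResult>.Parse(
    "DECIMAL(10,2)", (int)SparkConnection.ColumnTypeId.DECIMAL);
// decimalResult.BaseTypeName == "DECIMAL", Precision == 10, Scale == 2

SqlCharVarcharParserResult varcharResult =
    SqlTypeNameParser<SqlCharVarcharParserResult>.Parse("VARCHAR(255)");
// varcharResult.BaseTypeName == "VARCHAR", ColumnSize == 255
```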
@@ -46,7 +150,7 @@ internal abstract class SqlTypeNameParser where T : ParserResult /// The SQL type defintion string to parse. /// If successful, the result; otherwise null. /// True if it can successfully parse the type definition input string; otherwise false. - public bool TryParse(string input, out T? result) + internal bool TryParse(string input, out T? result) { Match match = Expression.Match(input); if (!match.Success) @@ -60,23 +164,57 @@ public bool TryParse(string input, out T? result) } /// - /// Parses the input string for a valid SQL type definition and returns the result or returns the defaultValue, if invalid. + /// Tries to parse the input SQL type name. If a matching parser is found and can parse the type name, it's result is set in parserResult and true is returned. + /// If a matching parser is not found parserResult is set to null and false is returned. /// - /// The SQL type defintion string to parse. - /// If input string is an invalid type definition, this result is returned instead. - /// If input string is a valid SQL type definition, it returns the result; otherwise defaultValue. - public T ParseOrDefault(string input, T defaultValue) + /// The SQL type name to parse + /// The result of a successful parse, null otherwise + /// The column type id as a hint to find the most appropriate parser + /// true if a matching parser is able to parse the SQL type name, false otherwise + internal static bool TryParse(string input, out SqlTypeNameParserResult? parserResult, int? columnTypeIdHint = null) { - return TryParse(input, out T? result) ? result! : defaultValue; + // Note: there may be multiple calls that successfully add/set the value in the cache + // - but the parser will produce the same result in each case. + string trimmedInput = input.Trim(); + if (s_cache.ContainsKey(trimmedInput)) + { + parserResult = s_cache[trimmedInput]; + return true; + } + + ISqlTypeNameParser? sqlTypeNameParser = null; + if (columnTypeIdHint != null && s_parserMap.ContainsKey(columnTypeIdHint.Value)) + { + sqlTypeNameParser = s_parserMap[columnTypeIdHint.Value]; + if (sqlTypeNameParser.TryParse(input, out SqlTypeNameParserResult? result) && result != null) + { + parserResult = result; + s_cache[trimmedInput] = result; + return true; + } + } + foreach (ISqlTypeNameParser parser in s_parsers) + { + if (parser == sqlTypeNameParser) continue; + if (parser.TryParse(input, out SqlTypeNameParserResult? result) && result != null) + { + parserResult = result; + s_cache[trimmedInput] = result; + return true; + } + } + + parserResult = null; + return false; } } /// - /// An result for parsing a SQL data type. + /// A result for parsing a SQL data type. /// /// The original SQL type name to parse /// The 'base' type name to use which is typically more simple without sub-clauses - internal class ParserResult(string typeName, string baseTypeName) + internal class SqlTypeNameParserResult(string typeName, string baseTypeName) { /// /// The original SQL type name @@ -87,6 +225,19 @@ internal class ParserResult(string typeName, string baseTypeName) /// The 'base' type name to use which is typically more simple without sub-clauses /// public string BaseTypeName { get; } = baseTypeName; + + public override bool Equals(object? 
obj) + { + if (ReferenceEquals(this, obj)) return true; + if (obj is not SqlTypeNameParserResult other) return false; + return TypeName.Equals(other.TypeName) + && BaseTypeName.Equals(other.BaseTypeName); + } + + public override int GetHashCode() + { + return TypeName.GetHashCode() ^ BaseTypeName.GetHashCode(); + } } /// @@ -95,12 +246,23 @@ internal class ParserResult(string typeName, string baseTypeName) /// The original SQL type name to parse /// The 'base' type name without the length clause /// The length of the column for this type name - internal class SqlCharVarcharParserResult(string typeName, string baseTypeName, int columnSize) : ParserResult(typeName, baseTypeName) + internal class SqlCharVarcharParserResult(string typeName, string baseTypeName, int columnSize = SqlVarcharTypeParser.VarcharColumnSizeDefault) : SqlTypeNameParserResult(typeName, baseTypeName) { /// /// The length of the column for this type name /// public int ColumnSize { get; } = columnSize; + + public override bool Equals(object? obj) => obj is SqlCharVarcharParserResult result + && base.Equals(obj) + && TypeName == result.TypeName + && BaseTypeName == result.BaseTypeName + && ColumnSize == result.ColumnSize; + + public override int GetHashCode() => base.GetHashCode() + ^ TypeName.GetHashCode() + ^ BaseTypeName.GetHashCode() + ^ ColumnSize.GetHashCode(); } /// @@ -110,7 +272,7 @@ internal class SqlCharVarcharParserResult(string typeName, string baseTypeName, /// The 'base' type name without the precision or scale clause /// The precision of the decimal type /// The scale (decimal digits) of the decimal type - internal class SqlDecimalParserResult(string typeName, string baseTypeName, int precision, int scale) : ParserResult(typeName, baseTypeName) + internal class SqlDecimalParserResult(string typeName, string baseTypeName, int precision, int scale) : SqlTypeNameParserResult(typeName, baseTypeName) { /// /// Constructs a new default result given the original type name. @@ -132,6 +294,37 @@ public SqlDecimalParserResult(string typeName) : this(typeName, "DECIMAL", SqlDe /// The representing the parsed type name /// public Decimal128Type Decimal128Type { get; } = new Decimal128Type(precision, scale); + + public override bool Equals(object? obj) => obj is SqlDecimalParserResult result + && base.Equals(obj) + && TypeName == result.TypeName + && BaseTypeName == result.BaseTypeName + && Precision == result.Precision + && Scale == result.Scale + && EqualityComparer.Default.Equals(Decimal128Type, result.Decimal128Type); + + public override int GetHashCode() => base.GetHashCode() + ^ TypeName.GetHashCode() + ^ BaseTypeName.GetHashCode() + ^ Precision.GetHashCode() + ^ Scale.GetHashCode() + ^ Decimal128Type.GetHashCode(); + } + + internal class SqlIntervalParserResult(string typeName, string baseTypeName, string qualifiers) : SqlTypeNameParserResult(typeName, baseTypeName) + { + public string Qualifiers { get; } = qualifiers; + + public override bool Equals(object? 
obj) => obj is SqlIntervalParserResult result + && base.Equals(obj) + && TypeName == result.TypeName + && BaseTypeName == result.BaseTypeName + && Qualifiers == result.Qualifiers; + + public override int GetHashCode() => base.GetHashCode() + ^ TypeName.GetHashCode() + ^ BaseTypeName.GetHashCode() + ^ Qualifiers.GetHashCode(); } /// @@ -139,15 +332,17 @@ public SqlDecimalParserResult(string typeName) : this(typeName, "DECIMAL", SqlDe /// internal class SqlCharTypeParser : SqlTypeNameParser { - private const string BaseTypeName = "CHAR"; + public static SqlCharTypeParser Default { get; } = new(); + + public override string BaseTypeName => "CHAR"; private static readonly Regex s_expression = new( @"^\s*(?((CHAR)|(NCHAR)))(\s*\(\s*(?\d{1,10})\s*\))\s*$", RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); - public override Regex Expression => s_expression; + protected override Regex Expression => s_expression; - public override SqlCharVarcharParserResult GenerateResult(string input, Match match) + protected override SqlCharVarcharParserResult GenerateResult(string input, Match match) { GroupCollection groups = match.Groups; Group precisionGroup = groups["precision"]; @@ -165,17 +360,19 @@ public override SqlCharVarcharParserResult GenerateResult(string input, Match ma internal class SqlVarcharTypeParser : SqlTypeNameParser { internal const int VarcharColumnSizeDefault = int.MaxValue; - - private const string VarcharBaseTypeName = "VARCHAR"; private const string StringBaseTypeName = "STRING"; + public static SqlVarcharTypeParser Default => new(); + + public override string BaseTypeName => "VARCHAR"; + private static readonly Regex s_expression = new( @"^\s*(?((STRING)|(VARCHAR)|(LONGVARCHAR)|(LONGNVARCHAR)|(NVARCHAR)))(\s*\(\s*(?\d{1,10})\s*\))?\s*$", RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); - public override Regex Expression => s_expression; + protected override Regex Expression => s_expression; - public override SqlCharVarcharParserResult GenerateResult(string input, Match match) + protected override SqlCharVarcharParserResult GenerateResult(string input, Match match) { GroupCollection groups = match.Groups; Group precisionGroup = groups["precision"]; @@ -183,7 +380,7 @@ public override SqlCharVarcharParserResult GenerateResult(string input, Match ma string baseTypeName = typeNameGroup.Value.Equals(StringBaseTypeName, StringComparison.InvariantCultureIgnoreCase) ? StringBaseTypeName - : VarcharBaseTypeName; + : BaseTypeName; int precision = precisionGroup.Success && int.TryParse(precisionGroup.Value, out int candidatePrecision) ? 
candidatePrecision : VarcharColumnSizeDefault; @@ -199,7 +396,9 @@ internal class SqlDecimalTypeParser : SqlTypeNameParser internal const int DecimalPrecisionDefault = 10; internal const int DecimalScaleDefault = 0; - private const string BaseTypeName = "DECIMAL"; + public static SqlDecimalTypeParser Default => new(); + + public override string BaseTypeName => "DECIMAL"; // Pattern is based on this definition // https://docs.databricks.com/en/sql/language-manual/data-types/decimal-type.html#syntax @@ -210,9 +409,9 @@ internal class SqlDecimalTypeParser : SqlTypeNameParser @"^\s*(?((DECIMAL)|(DEC)|(NUMERIC)))(\s*\(\s*((?\d{1,2})(\s*\,\s*(?\d{1,2}))?)\s*\))?\s*$", RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); - public override Regex Expression => s_expression; + protected override Regex Expression => s_expression; - public override SqlDecimalParserResult GenerateResult(string input, Match match) + protected override SqlDecimalParserResult GenerateResult(string input, Match match) { GroupCollection groups = match.Groups; Group precisionGroup = groups["precision"]; @@ -224,4 +423,138 @@ public override SqlDecimalParserResult GenerateResult(string input, Match match) return new SqlDecimalParserResult(input, BaseTypeName, precision, scale); } } + + /// + /// Provides a parser for SQL INTEGER type definitions. + /// + internal class SqlIntegerTypeParser : SqlTypeNameParser + { + public static SqlIntegerTypeParser Default => new(); + + public override string BaseTypeName => "INTEGER"; + + // Pattern is based on this definition + // https://docs.databricks.com/en/sql/language-manual/data-types/int-type.html#syntax + // { INT | INTEGER } + private static readonly Regex s_expression = new( + @"^\s*(?((INTEGER)|(INT)))\s*$", + RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); + + protected override Regex Expression => s_expression; + } + + /// + /// Provides a parser for SQL TIMESTAMP type definitions. + /// + internal class SqlTimestampTypeParser : SqlTypeNameParser + { + public static SqlTimestampTypeParser Default => new(); + + public override string BaseTypeName => "TIMESTAMP"; + + // Pattern is based on this definition + // https://docs.databricks.com/en/sql/language-manual/data-types/map-type.html#syntax + // MAP + // keyType: Any data type other than MAP specifying the keys. + // valueType: Any data type specifying the values. + private static readonly Regex s_expression = new( + @"^\s*(?((TIMESTAMP)|(TIMESTAMP_LTZ)|(TIMESTAMP_NTZ)))\s*$", + RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); + + protected override Regex Expression => s_expression; + } + + /// + /// Provides a parser for SQL STRUCT type definitions. + /// + internal class SqlStructTypeParser : SqlTypeNameParser + { + public static SqlStructTypeParser Default => new(); + + public override string BaseTypeName => "STRUCT"; + + // Pattern is based on this definition + // https://docs.databricks.com/en/sql/language-manual/data-types/struct-type.html#syntax + // STRUCT < [fieldName [:] fieldType [NOT NULL] [COMMENT str] [, …] ] > + // fieldName: An identifier naming the field. The names need not be unique. + // fieldType: Any data type. + // NOT NULL: When specified the struct guarantees that the value of this field is never NULL. + // COMMENT str: An optional string literal describing the field. 
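Editor's note: an illustrative aside, not part of the diff. The STRUCT parser being defined here (like the ARRAY and MAP parsers that follow) only validates the shape of the definition and reduces it to its base type name; the nested field definitions are not decomposed.

```csharp
// Illustrative only (not part of this PR).
bool ok = SqlStructTypeParser.Default.TryParse(
    "STRUCT<name: STRING, age: INT NOT NULL>", out SqlTypeNameParserResult? result);
// ok == true; result!.BaseTypeName == "STRUCT"; result.TypeName keeps the original definition
```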
+ private static readonly Regex s_expression = new( + @"^\s*(?STRUCT)(?\s*\<(.+)\>)\s*$", // STUCT + RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); + + protected override Regex Expression => s_expression; + } + + /// + /// Provides a parser for SQL ARRAY type definitions. + /// + internal class SqlArrayTypeParser : SqlTypeNameParser + { + public static SqlArrayTypeParser Default => new(); + + public override string BaseTypeName => "ARRAY"; + + // Pattern is based on this definition + // https://docs.databricks.com/en/sql/language-manual/data-types/array-type.html#syntax + // ARRAY < elementType > + // elementType: Any data type defining the type of the elements of the array. + private static readonly Regex s_expression = new( + @"^\s*(?ARRAY)(?\s*\<(.+)\>)\s*$", + RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); + + protected override Regex Expression => s_expression; + } + + /// + /// Provides a parser for SQL MAP type definitions. + /// + internal class SqlMapTypeParser : SqlTypeNameParser + { + public static SqlMapTypeParser Default => new(); + + public override string BaseTypeName => "MAP"; + + // Pattern is based on this definition + // https://docs.databricks.com/en/sql/language-manual/data-types/map-type.html#syntax + // MAP + // keyType: Any data type other than MAP specifying the keys. + // valueType: Any data type specifying the values. + private static readonly Regex s_expression = new( + @"^\s*(?MAP)(?\s*\<(.+)\>)\s*$", + RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); + + protected override Regex Expression => s_expression; + } + + internal class SqlIntervalTypeParser : SqlTypeNameParser + { + public static SqlIntervalTypeParser Default => new(); + + public override string BaseTypeName { get; } = "INTERVAL"; + + // See: https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html#syntax + private static readonly Regex s_expression = new( + @"^\s*(?INTERVAL)\s+.*$", + RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); + + protected override Regex Expression => s_expression; + } + + internal class SqlSimpleTypeParser(string baseTypeName) : SqlTypeNameParser + { + private static readonly ConcurrentDictionary s_parserMap = new ConcurrentDictionary(); + + public static SqlSimpleTypeParser Default(string baseTypeName) + { + return s_parserMap.GetOrAdd(baseTypeName, (typeName) => new SqlSimpleTypeParser(typeName)); + } + + public override string BaseTypeName { get; } = baseTypeName; + + protected override Regex Expression => new( + @"^\s*" + Regex.Escape(BaseTypeName) + @"\s*$", + RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant); + } } diff --git a/csharp/src/Drivers/Apache/Thrift/BitmapUtilities.cs b/csharp/src/Drivers/Apache/Thrift/BitmapUtilities.cs new file mode 100644 index 0000000000..ae23ac872c --- /dev/null +++ b/csharp/src/Drivers/Apache/Thrift/BitmapUtilities.cs @@ -0,0 +1,62 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Thrift +{ + internal static class BitmapUtilities + { + private static readonly byte[] s_bitMasks = [0, 0b00000001, 0b00000011, 0b00000111, 0b00001111, 0b00011111, 0b00111111, 0b01111111, 0b11111111]; + + /// + /// Gets the "validity" bitmap buffer from a 'nulls' bitmap. + /// + /// The bitmap of rows where the value is a null value (i.e., "invalid") + /// The length of the array. + /// Returns the number of bits set in the bitmap. + /// A bitmap of "valid" rows (i.e., not null values). + /// Inverts the bits in the incoming bitmap to reverse the null to valid indicators. + internal static ArrowBuffer GetValidityBitmapBuffer(ref byte[] nulls, int arrayLength, out int nullCount) + { + nullCount = BitUtility.CountBits(nulls); + + int fullBytes = arrayLength / 8; + int remainingBits = arrayLength % 8; + int requiredBytes = fullBytes + (remainingBits == 0 ? 0 : 1); + if (nulls.Length < requiredBytes) + { + // Note: Spark may return a nulls bitmap buffer that is shorter than required - implying that missing bits indicate non-null. + // However, since we need to invert the bits and return a "validity" bitmap, we need to have a full length bitmap. + byte[] temp = new byte[requiredBytes]; + nulls.CopyTo(temp, 0); + nulls = temp; + } + + // Handle full bytes + for (int i = 0; i < fullBytes; i++) + { + nulls[i] = (byte)~nulls[i]; + } + // Handle remaing bits + if (remainingBits > 0) + { + int lastByteIndex = requiredBytes - 1; + nulls[lastByteIndex] = (byte)(s_bitMasks[remainingBits] & (byte)~nulls[lastByteIndex]); + } + return new ArrowBuffer(nulls); + } + } +} diff --git a/csharp/src/Drivers/Apache/Thrift/SchemaParser.cs b/csharp/src/Drivers/Apache/Thrift/SchemaParser.cs index 16a0f8356d..bee4a8485b 100644 --- a/csharp/src/Drivers/Apache/Thrift/SchemaParser.cs +++ b/csharp/src/Drivers/Apache/Thrift/SchemaParser.cs @@ -16,67 +16,38 @@ */ using System; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; namespace Apache.Arrow.Adbc.Drivers.Apache { - internal class SchemaParser + internal abstract class SchemaParser { - internal static Schema GetArrowSchema(TTableSchema thriftSchema) + internal Schema GetArrowSchema(TTableSchema thriftSchema, DataTypeConversion dataTypeConversion) { Field[] fields = new Field[thriftSchema.Columns.Count]; for (int i = 0; i < thriftSchema.Columns.Count; i++) { TColumnDesc column = thriftSchema.Columns[i]; // Note: no nullable metadata is returned from the Thrift interface. 
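Editor's note: a small worked example, not part of the diff, of the `BitmapUtilities.GetValidityBitmapBuffer` helper added above, showing the bit inversion and the masking of the trailing byte.

```csharp
// Illustrative only (not part of this PR): 10 rows where rows 0 and 9 are null.
byte[] nulls = new byte[] { 0b0000_0001, 0b0000_0010 };
ArrowBuffer validity = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, 10, out int nullCount);
// nullCount == 2
// nulls now holds the validity bitmap { 0b1111_1110, 0b0000_0001 }:
// every bit is inverted, and bits past row 9 in the last byte are masked to zero.
```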
- fields[i] = new Field(column.ColumnName, GetArrowType(column.TypeDesc.Types[0]), nullable: true /* assumed */); + fields[i] = new Field(column.ColumnName, GetArrowType(column.TypeDesc.Types[0], dataTypeConversion), nullable: true /* assumed */); } return new Schema(fields, null); } - static IArrowType GetArrowType(TTypeEntry thriftType) + IArrowType GetArrowType(TTypeEntry thriftType, DataTypeConversion dataTypeConversion) { if (thriftType.PrimitiveEntry != null) { - return GetArrowType(thriftType.PrimitiveEntry); + return GetArrowType(thriftType.PrimitiveEntry, dataTypeConversion); } throw new InvalidOperationException(); } - public static IArrowType GetArrowType(TPrimitiveTypeEntry thriftType) - { - switch (thriftType.Type) - { - case TTypeId.BIGINT_TYPE: return Int64Type.Default; - case TTypeId.BINARY_TYPE: return BinaryType.Default; - case TTypeId.BOOLEAN_TYPE: return BooleanType.Default; - case TTypeId.CHAR_TYPE: return StringType.Default; - case TTypeId.DATE_TYPE: return Date32Type.Default; - case TTypeId.DOUBLE_TYPE: return DoubleType.Default; - case TTypeId.FLOAT_TYPE: return FloatType.Default; - case TTypeId.INT_TYPE: return Int32Type.Default; - case TTypeId.NULL_TYPE: return NullType.Default; - case TTypeId.SMALLINT_TYPE: return Int16Type.Default; - case TTypeId.STRING_TYPE: return StringType.Default; - case TTypeId.TIMESTAMP_TYPE: return new TimestampType(TimeUnit.Microsecond, (string?)null); - case TTypeId.TINYINT_TYPE: return Int8Type.Default; - case TTypeId.VARCHAR_TYPE: return StringType.Default; - case TTypeId.DECIMAL_TYPE: - int precision = thriftType.TypeQualifiers.Qualifiers["precision"].I32Value; - int scale = thriftType.TypeQualifiers.Qualifiers["scale"].I32Value; - return new Decimal128Type(precision, scale); - case TTypeId.INTERVAL_DAY_TIME_TYPE: - case TTypeId.INTERVAL_YEAR_MONTH_TYPE: - case TTypeId.ARRAY_TYPE: - case TTypeId.MAP_TYPE: - case TTypeId.STRUCT_TYPE: - case TTypeId.UNION_TYPE: - case TTypeId.USER_DEFINED_TYPE: - return StringType.Default; - default: - throw new NotImplementedException(); - } - } + public abstract IArrowType GetArrowType(TPrimitiveTypeEntry thriftType, DataTypeConversion dataTypeConversion); + + protected static Decimal128Type NewDecima128Type(TPrimitiveTypeEntry thriftType) => + new(thriftType.TypeQualifiers.Qualifiers["precision"].I32Value, thriftType.TypeQualifiers.Qualifiers["scale"].I32Value); } } diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TBinaryColumn.cs b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TBinaryColumn.cs index e5259a6532..87e92bdf8e 100644 --- a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TBinaryColumn.cs +++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TBinaryColumn.cs @@ -30,7 +30,6 @@ namespace Apache.Hive.Service.Rpc.Thrift public partial class TBinaryColumn : TBase { - public BinaryArray Values { get; set; } public TBinaryColumn() @@ -83,14 +82,12 @@ public TBinaryColumn DeepCopy() values = new ArrowBuffer.Builder(); int offset = 0; - offsetBuffer = new byte[(length + 1) * 4]; + offsetBuffer = new byte[(length + 1) * sizeof(int)]; var memory = offsetBuffer.AsMemory(); - var typedMemory = Unsafe.As, Memory>(ref memory).Slice(0, length + 1); for(int _i197 = 0; _i197 < length; ++_i197) { - //typedMemory.Span[_i197] = offset; - StreamExtensions.WriteInt32LittleEndian(offset, memory.Span, _i197 * 4); + StreamExtensions.WriteInt32LittleEndian(offset, memory.Span, _i197 * sizeof(int)); var size = await iprot.ReadI32Async(cancellationToken); offset += size; @@ -109,8 
+106,7 @@ public TBinaryColumn DeepCopy() await transport.ReadExactlyAsync(tmp.AsMemory(0, size), cancellationToken); values.Append(tmp.AsMemory(0, size).Span); } - typedMemory.Span[length] = offset; - StreamExtensions.WriteInt32LittleEndian(offset, memory.Span, length * 4); + StreamExtensions.WriteInt32LittleEndian(offset, memory.Span, length * sizeof(int)); await iprot.ReadListEndAsync(cancellationToken); } @@ -150,7 +146,8 @@ public TBinaryColumn DeepCopy() throw new TProtocolException(TProtocolException.INVALID_DATA); } - Values = new BinaryArray(BinaryType.Default, length, new ArrowBuffer(offsetBuffer), values.Build(), new ArrowBuffer(nulls), BitUtility.CountBits(nulls)); + ArrowBuffer validityBitmapBuffer = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, length, out int nullCount); + Values = new BinaryArray(BinaryType.Default, length, new ArrowBuffer(offsetBuffer), values.Build(), validityBitmapBuffer, nullCount); } finally { diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TBoolColumn.cs b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TBoolColumn.cs index 2288c4acd3..c635abc124 100644 --- a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TBoolColumn.cs +++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TBoolColumn.cs @@ -124,7 +124,8 @@ public TBoolColumn DeepCopy() throw new TProtocolException(TProtocolException.INVALID_DATA); } - Values = new BooleanArray(values.Build(), new ArrowBuffer(nulls), length, BitUtility.CountBits(nulls), 0); + ArrowBuffer validityBitmapBuffer = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, length, out int nullCount); + Values = new BooleanArray(values.Build(), validityBitmapBuffer, length, nullCount, 0); } finally { diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TByteColumn.cs b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TByteColumn.cs index 5c7f129c34..3a67288429 100644 --- a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TByteColumn.cs +++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TByteColumn.cs @@ -120,8 +120,8 @@ public TByteColumn DeepCopy() { throw new TProtocolException(TProtocolException.INVALID_DATA); } - - Values = new Int8Array(new ArrowBuffer(buffer), new ArrowBuffer(nulls), length, BitUtility.CountBits(nulls), 0); + ArrowBuffer validityBitmapBuffer = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, length, out int nullCount); + Values = new Int8Array(new ArrowBuffer(buffer), validityBitmapBuffer, length, nullCount, 0); } finally { diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TDoubleColumn.cs b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TDoubleColumn.cs index c87a05e0dd..24deb6c971 100644 --- a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TDoubleColumn.cs +++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TDoubleColumn.cs @@ -29,7 +29,6 @@ namespace Apache.Hive.Service.Rpc.Thrift public partial class TDoubleColumn : TBase { - public DoubleArray Values { get; set; } public TDoubleColumn() @@ -78,14 +77,13 @@ public TDoubleColumn DeepCopy() var _list178 = await iprot.ReadListBeginAsync(cancellationToken); length = _list178.Count; - buffer = new byte[length * 8]; + buffer = new byte[length * sizeof(double)]; var memory = buffer.AsMemory(); - var typedMemory = Unsafe.As, Memory>(ref memory).Slice(0, length); iprot.Transport.CheckReadBytesAvailable(buffer.Length); await transport.ReadExactlyAsync(memory, cancellationToken); for (int _i179 = 0; _i179 < length; ++_i179) { - typedMemory.Span[_i179] = 
BinaryPrimitives.ReverseEndianness(typedMemory.Span[_i179]); + StreamExtensions.ReverseEndianI64AtOffset(memory.Span, _i179 * sizeof(double)); } await iprot.ReadListEndAsync(cancellationToken); } @@ -125,7 +123,8 @@ public TDoubleColumn DeepCopy() throw new TProtocolException(TProtocolException.INVALID_DATA); } - Values = new DoubleArray(new ArrowBuffer(buffer), new ArrowBuffer(nulls), length, BitUtility.CountBits(nulls), 0); + ArrowBuffer validityBitmapBuffer = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, length, out int nullCount); + Values = new DoubleArray(new ArrowBuffer(buffer), validityBitmapBuffer, length, nullCount, 0); } finally { diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI16Column.cs b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI16Column.cs index e8de6f0207..b300a5079a 100644 --- a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI16Column.cs +++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI16Column.cs @@ -78,15 +78,13 @@ public TI16Column DeepCopy() var _list151 = await iprot.ReadListBeginAsync(cancellationToken); length = _list151.Count; - buffer = new byte[length * 2]; + buffer = new byte[length * sizeof(short)]; var memory = buffer.AsMemory(); - var typedMemory = Unsafe.As, Memory>(ref memory).Slice(0, length); iprot.Transport.CheckReadBytesAvailable(buffer.Length); await transport.ReadExactlyAsync(memory, cancellationToken); for (int _i152 = 0; _i152 < length; ++_i152) { - //typedMemory.Span[_i152] = BinaryPrimitives.ReverseEndianness(typedMemory.Span[_i152]); - StreamExtensions.ReverseEndiannessInt16(memory.Span, _i152 * 2); + StreamExtensions.ReverseEndiannessInt16(memory.Span, _i152 * sizeof(short)); } await iprot.ReadListEndAsync(cancellationToken); } @@ -125,8 +123,8 @@ public TI16Column DeepCopy() { throw new TProtocolException(TProtocolException.INVALID_DATA); } - - Values = new Int16Array(new ArrowBuffer(buffer), new ArrowBuffer(nulls), length, BitUtility.CountBits(nulls), 0); + ArrowBuffer validityBitmapBuffer = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, length, out int nullCount); + Values = new Int16Array(new ArrowBuffer(buffer), validityBitmapBuffer, length, nullCount, 0); } finally { diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI32Column.cs b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI32Column.cs index e896a9aaad..9530933ca4 100644 --- a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI32Column.cs +++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI32Column.cs @@ -77,15 +77,13 @@ public TI32Column DeepCopy() { var _list160 = await iprot.ReadListBeginAsync(cancellationToken); length = _list160.Count; - buffer = new byte[length * 4]; + buffer = new byte[length * sizeof(int)]; var memory = buffer.AsMemory(); - var typedMemory = Unsafe.As, Memory>(ref memory).Slice(0, length); iprot.Transport.CheckReadBytesAvailable(buffer.Length); await transport.ReadExactlyAsync(memory, cancellationToken); for (int _i161 = 0; _i161 < length; ++_i161) { - //typedMemory.Span[_i161] = BinaryPrimitives.ReverseEndianness(typedMemory.Span[_i161]); - StreamExtensions.ReverseEndianI32AtOffset(memory.Span, _i161 * 4); + StreamExtensions.ReverseEndianI32AtOffset(memory.Span, _i161 * sizeof(int)); } await iprot.ReadListEndAsync(cancellationToken); } @@ -124,8 +122,8 @@ public TI32Column DeepCopy() { throw new TProtocolException(TProtocolException.INVALID_DATA); } - - Values = new Int32Array(new ArrowBuffer(buffer), new ArrowBuffer(nulls), length, BitUtility.CountBits(nulls), 0); + 
ArrowBuffer validityBitmapBuffer = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, length, out int nullCount); + Values = new Int32Array(new ArrowBuffer(buffer), validityBitmapBuffer, length, nullCount, 0); } finally { diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI64Column.cs b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI64Column.cs index 794da6e624..6de0b3a481 100644 --- a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI64Column.cs +++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TI64Column.cs @@ -29,7 +29,6 @@ namespace Apache.Hive.Service.Rpc.Thrift public partial class TI64Column : TBase { - public Int64Array Values { get; set; } public TI64Column() @@ -78,14 +77,13 @@ public TI64Column DeepCopy() var _list169 = await iprot.ReadListBeginAsync(cancellationToken); length = _list169.Count; - buffer = new byte[length * 8]; + buffer = new byte[length * sizeof(long)]; var memory = buffer.AsMemory(); - var typedMemory = Unsafe.As, Memory>(ref memory).Slice(0, length); iprot.Transport.CheckReadBytesAvailable(buffer.Length); await transport.ReadExactlyAsync(memory, cancellationToken); for (int _i170 = 0; _i170 < length; ++_i170) { - typedMemory.Span[_i170] = BinaryPrimitives.ReverseEndianness(typedMemory.Span[_i170]); + StreamExtensions.ReverseEndianI64AtOffset(memory.Span, _i170 * sizeof(long)); } await iprot.ReadListEndAsync(cancellationToken); } @@ -124,8 +122,8 @@ public TI64Column DeepCopy() { throw new TProtocolException(TProtocolException.INVALID_DATA); } - - Values = new Int64Array(new ArrowBuffer(buffer), new ArrowBuffer(nulls), length, BitUtility.CountBits(nulls), 0); + ArrowBuffer validityBitmapBuffer = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, length, out int nullCount); + Values = new Int64Array(new ArrowBuffer(buffer), validityBitmapBuffer, length, nullCount, 0); } finally { diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TStringColumn.cs b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TStringColumn.cs index c90631ef8a..d0f691b14d 100644 --- a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TStringColumn.cs +++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TStringColumn.cs @@ -84,11 +84,9 @@ public TStringColumn DeepCopy() int offset = 0; offsetBuffer = new byte[(length + 1) * 4]; var memory = offsetBuffer.AsMemory(); - var typedMemory = Unsafe.As, Memory>(ref memory).Slice(0, length + 1); for(int _i188 = 0; _i188 < length; ++_i188) { - //typedMemory.Span[_i188] = offset; StreamExtensions.WriteInt32LittleEndian(offset, memory.Span, _i188 * 4); var size = await iprot.ReadI32Async(cancellationToken); @@ -109,7 +107,6 @@ public TStringColumn DeepCopy() await transport.ReadExactlyAsync(tmp.AsMemory(0, size), cancellationToken); values.Append(tmp.AsMemory(0, size).Span); } - //typedMemory.Span[length] = offset; StreamExtensions.WriteInt32LittleEndian(offset, memory.Span, length * 4); await iprot.ReadListEndAsync(cancellationToken); @@ -149,8 +146,8 @@ public TStringColumn DeepCopy() { throw new TProtocolException(TProtocolException.INVALID_DATA); } - - Values = new StringArray(length, new ArrowBuffer(offsetBuffer), values.Build(), new ArrowBuffer(nulls), BitUtility.CountBits(nulls)); + ArrowBuffer validityBitmapBuffer = BitmapUtilities.GetValidityBitmapBuffer(ref nulls, length, out int nullCount); + Values = new StringArray(length, new ArrowBuffer(offsetBuffer), values.Build(), validityBitmapBuffer, nullCount); } finally { diff --git a/csharp/src/Drivers/Apache/Thrift/StreamExtensions.cs 
b/csharp/src/Drivers/Apache/Thrift/StreamExtensions.cs index 4396518055..c252d9aff7 100644 --- a/csharp/src/Drivers/Apache/Thrift/StreamExtensions.cs +++ b/csharp/src/Drivers/Apache/Thrift/StreamExtensions.cs @@ -45,6 +45,22 @@ public static void WriteInt32LittleEndian(int value, Span buffer, int offs buffer[offset + 3] = (byte)(value >> 24); } + public static void ReverseEndianI64AtOffset(Span buffer, int offset) + { + // Check if the buffer is large enough to contain an i64 at the given offset + if (offset < 0 || buffer.Length < offset + sizeof(long)) + throw new ArgumentOutOfRangeException(nameof(offset), "Buffer is too small or offset is out of bounds."); + + // Swap the bytes to reverse the endianness of the i64 + byte temp; + for (int startIndex = offset, endIndex = offset + (sizeof(long) - 1); startIndex < endIndex; startIndex++, endIndex--) + { + temp = buffer[startIndex]; + buffer[startIndex] = buffer[endIndex]; + buffer[endIndex] = temp; + } + } + public static void ReverseEndianI32AtOffset(Span buffer, int offset) { // Check if the buffer is large enough to contain an i32 at the given offset @@ -64,6 +80,7 @@ public static void ReverseEndianI32AtOffset(Span buffer, int offset) buffer[offset + 1] = buffer[offset + 2]; buffer[offset + 2] = temp; } + public static void ReverseEndiannessInt16(Span buffer, int offset) { if (buffer == null) diff --git a/csharp/src/Drivers/Apache/Thrift/ThriftSocketTransport.cs b/csharp/src/Drivers/Apache/Thrift/ThriftSocketTransport.cs index e3acb34a15..b6950c4a81 100644 --- a/csharp/src/Drivers/Apache/Thrift/ThriftSocketTransport.cs +++ b/csharp/src/Drivers/Apache/Thrift/ThriftSocketTransport.cs @@ -34,6 +34,11 @@ public ThriftSocketTransport(string host, int port, TConfiguration config, int t { } + public ThriftSocketTransport(string hostNameOrIpAddress, int port, bool connectClient, TConfiguration config, int timeout = 0) + : base(hostNameOrIpAddress, port, connectClient, config, timeout) + { + } + public Stream Input { get { return this.InputStream; } } public Stream Output { get { return this.OutputStream; } } } diff --git a/csharp/src/Drivers/Apache/readme.md b/csharp/src/Drivers/Apache/readme.md index ec385f2e26..38d6160743 100644 --- a/csharp/src/Drivers/Apache/readme.md +++ b/csharp/src/Drivers/Apache/readme.md @@ -18,6 +18,7 @@ --> # Thrift-based Apache connectors + This library contains code for ADBC drivers built on top of the Thrift protocol with Arrow support: - Hive @@ -27,6 +28,7 @@ This library contains code for ADBC drivers built on top of the Thrift protocol Each driver is at a different state of implementation. ## Custom generation + Typically, [Thrift](https://thrift.apache.org/) code is generated from the Thrift compiler. And that is mostly true here as well. However, some files were further edited to include Arrow support. These contain the phrase `BUT THIS FILE HAS BEEN HAND EDITED TO SUPPORT ARROW SO REGENERATE AT YOUR OWN RISK` at the top. Some of these files include: ``` @@ -41,55 +43,26 @@ arrow-adbc/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/TStringColumn.cs ``` # Hive + The Hive classes serve as the base class for Spark and Impala, since both of those platform implement Hive capabilities. Core functionality of the Hive classes beyond the base library implementation is under development, has limited functionality, and may produce errors. # Impala + The Imapala classes are under development, have limited functionality, and may produce errors. 
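Editor's note: referring back to the `StreamExtensions.ReverseEndianI64AtOffset` helper added above, a brief illustration (not part of the diff) of the in-place byte swap it performs on the big-endian values read from the Thrift transport.

```csharp
// Illustrative only (not part of this PR).
byte[] buffer = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2A }; // big-endian 42
StreamExtensions.ReverseEndianI64AtOffset(buffer, 0);
// buffer is now { 0x2A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, i.e. little-endian 42,
// matching the layout the Arrow Int64Array/DoubleArray buffers expect on little-endian hosts.
```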
# Spark -The Spark classes are intended for use against native Spark and Spark on Databricks. -## Spark Types - -The following table depicts how the Spark ADBC driver converts a Spark type to an Arrow type and a .NET type: - -| Spark Type | Arrow Type | C# Type | -| :--- | :---: | :---: | -| ARRAY* | String | string | -| BIGINT | Int64 | long | -| BINARY | Binary | byte[] | -| BOOLEAN | Boolean | bool | -| CHAR | String | string | -| DATE | Date32 | DateTime | -| DECIMAL | Decimal128 | SqlDecimal | -| DOUBLE | Double | double | -| FLOAT | Float | float | -| INT | Int32 | int | -| INTERVAL_DAY_TIME+ | String | string | -| INTERVAL_YEAR_MONTH+ | String | string | -| MAP* | String | string | -| NULL | Null | null | -| SMALLINT | Int16 | short | -| STRING | String | string | -| STRUCT* | String | string | -| TIMESTAMP | Timestamp | DateTimeOffset | -| TINYINT | Int8 | sbyte | -| UNION | String | string | -| USER_DEFINED | String | string | -| VARCHAR | String | string | - -\* Complex types are returned as strings
-\+ Interval types are returned as strings +The Spark classes are intended for use against native Spark and Spark on Databricks. +For more details, see [Spark Driver](Spark/README.md) ## Known Limitations 1. The API `SparkConnection.GetObjects` is not fully tested at this time 1. It may not return all catalogs and schema in the server. 1. It may throw an exception when returning object metadata from multiple catalog and schema. -1. API `Connection.GetTableSchema` does not return correct precision and scale for `NUMERIC`/`DECIMAL` types. 1. When a `NULL` value is returned for a `BINARY` type it is instead being returned as an empty array instead of the expected `null`. 1. Result set metadata does not provide information about the nullability of each column. They are marked as `nullable` by default, which may not be accurate. 1. The **Impala** driver is untested and is currently unsupported. diff --git a/csharp/src/Drivers/BigQuery/Apache.Arrow.Adbc.Drivers.BigQuery.csproj b/csharp/src/Drivers/BigQuery/Apache.Arrow.Adbc.Drivers.BigQuery.csproj index 48f95c28ff..0aaa4913cd 100644 --- a/csharp/src/Drivers/BigQuery/Apache.Arrow.Adbc.Drivers.BigQuery.csproj +++ b/csharp/src/Drivers/BigQuery/Apache.Arrow.Adbc.Drivers.BigQuery.csproj @@ -4,10 +4,10 @@ readme.md - - + + - + diff --git a/csharp/src/Drivers/BigQuery/BigQueryConnection.cs b/csharp/src/Drivers/BigQuery/BigQueryConnection.cs index 629ac5d0c2..398365314f 100644 --- a/csharp/src/Drivers/BigQuery/BigQueryConnection.cs +++ b/csharp/src/Drivers/BigQuery/BigQueryConnection.cs @@ -997,23 +997,20 @@ private IReadOnlyDictionary ParseOptions() { Dictionary options = new Dictionary(); - foreach (KeyValuePair keyValuePair in this.properties) + string[] statementOptions = new string[] { + BigQueryParameters.AllowLargeResults, + BigQueryParameters.UseLegacySQL, + BigQueryParameters.LargeDecimalsAsString, + BigQueryParameters.LargeResultsDestinationTable, + BigQueryParameters.GetQueryResultsOptionsTimeoutMinutes, + BigQueryParameters.MaxFetchConcurrency + }; + + foreach (string key in statementOptions) { - if (keyValuePair.Key == BigQueryParameters.AllowLargeResults) - { - options[keyValuePair.Key] = keyValuePair.Value; - } - if (keyValuePair.Key == BigQueryParameters.UseLegacySQL) - { - options[keyValuePair.Key] = keyValuePair.Value; - } - if (keyValuePair.Key == BigQueryParameters.LargeDecimalsAsString) - { - options[keyValuePair.Key] = keyValuePair.Value; - } - if (keyValuePair.Key == BigQueryParameters.LargeResultsDestinationTable) + if (properties.TryGetValue(key, out string? 
value)) { - options[keyValuePair.Key] = keyValuePair.Value; + options[key] = value; } } diff --git a/csharp/src/Drivers/BigQuery/BigQueryParameters.cs b/csharp/src/Drivers/BigQuery/BigQueryParameters.cs index 0f81291011..51272eb643 100644 --- a/csharp/src/Drivers/BigQuery/BigQueryParameters.cs +++ b/csharp/src/Drivers/BigQuery/BigQueryParameters.cs @@ -34,6 +34,8 @@ public class BigQueryParameters public const string LargeDecimalsAsString = "adbc.bigquery.large_decimals_as_string"; public const string Scopes = "adbc.bigquery.scopes"; public const string IncludeConstraintsWithGetObjects = "adbc.bigquery.include_constraints_getobjects"; + public const string GetQueryResultsOptionsTimeoutMinutes = "adbc.bigquery.get_query_results_options.timeout"; + public const string MaxFetchConcurrency = "adbc.bigquery.max_fetch_concurrency"; } /// diff --git a/csharp/src/Drivers/BigQuery/BigQueryStatement.cs b/csharp/src/Drivers/BigQuery/BigQueryStatement.cs index 297becff6d..ab7764b032 100644 --- a/csharp/src/Drivers/BigQuery/BigQueryStatement.cs +++ b/csharp/src/Drivers/BigQuery/BigQueryStatement.cs @@ -16,16 +16,14 @@ */ using System; -using System.Collections; using System.Collections.Generic; -using System.Data.SqlTypes; using System.IO; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; using Apache.Arrow.Types; +using Google.Api.Gax; using Google.Apis.Auth.OAuth2; using Google.Apis.Bigquery.v2.Data; using Google.Cloud.BigQuery.Storage.V1; @@ -55,16 +53,66 @@ public override QueryResult ExecuteQuery() { QueryOptions? queryOptions = ValidateOptions(); BigQueryJob job = this.client.CreateQueryJob(SqlQuery, null, queryOptions); - BigQueryResults results = job.GetQueryResults(); + + GetQueryResultsOptions getQueryResultsOptions = new GetQueryResultsOptions(); + + if (this.Options?.TryGetValue(BigQueryParameters.GetQueryResultsOptionsTimeoutMinutes, out string? timeoutMinutes) == true) + { + if (int.TryParse(timeoutMinutes, out int minutes)) + { + if (minutes >= 0) + { + getQueryResultsOptions.Timeout = TimeSpan.FromMinutes(minutes); + } + } + } + + BigQueryResults results = job.GetQueryResults(getQueryResultsOptions); BigQueryReadClientBuilder readClientBuilder = new BigQueryReadClientBuilder(); readClientBuilder.Credential = this.credential; BigQueryReadClient readClient = readClientBuilder.Build(); + if (results.TableReference == null) + { + // To get the results of all statements in a multi-statement query, enumerate the child jobs and call jobs.getQueryResults on each of them. 
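Editor's note: the two options read above are plain string key/value pairs. Below is a hedged sketch (not part of the diff) of how a caller might supply them, using the constants added to `BigQueryParameters` above; how the dictionary reaches the connection and statement is driver plumbing assumed here.

```csharp
// Illustrative only (not part of this PR).
var options = new Dictionary<string, string>
{
    [BigQueryParameters.GetQueryResultsOptionsTimeoutMinutes] = "10", // wait up to 10 minutes for results
    [BigQueryParameters.MaxFetchConcurrency] = "4",                   // request up to 4 Storage API read streams
};
```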
+ // Related public docs: https://cloud.google.com/bigquery/docs/multi-statement-queries#get_all_executed_statements + ListJobsOptions listJobsOptions = new ListJobsOptions(); + listJobsOptions.ParentJobId = results.JobReference.JobId; + PagedEnumerable joblist = client.ListJobs(listJobsOptions); + BigQueryJob firstQueryJob = new BigQueryJob(client, job.Resource); + foreach (BigQueryJob childJob in joblist) + { + var tempJob = client.GetJob(childJob.Reference.JobId); + var query = tempJob.Resource?.Configuration?.Query; + if (query != null && query.DestinationTable != null && query.DestinationTable.ProjectId != null && query.DestinationTable.DatasetId != null && query.DestinationTable.TableId != null) + { + firstQueryJob = tempJob; + } + } + results = firstQueryJob.GetQueryResults(); + } + + if (results.TableReference == null) + { + throw new AdbcException("There is no query statement"); + } + string table = $"projects/{results.TableReference.ProjectId}/datasets/{results.TableReference.DatasetId}/tables/{results.TableReference.TableId}"; + int maxStreamCount = 1; + if (this.Options?.TryGetValue(BigQueryParameters.MaxFetchConcurrency, out string? maxStreamCountString) == true) + { + if (int.TryParse(maxStreamCountString, out int count)) + { + if (count >= 0) + { + maxStreamCount = count; + } + } + } ReadSession rs = new ReadSession { Table = table, DataFormat = DataFormat.Arrow }; - ReadSession rrs = readClient.CreateReadSession("projects/" + results.TableReference.ProjectId, rs, 1); + ReadSession rrs = readClient.CreateReadSession("projects/" + results.TableReference.ProjectId, rs, maxStreamCount); long totalRows = results.TotalRows == null ? -1L : (long)results.TotalRows.Value; IArrowArrayStream stream = new MultiArrowReader(TranslateSchema(results.Schema), rrs.Streams.Select(s => ReadChunk(readClient, s.Name))); @@ -90,19 +138,6 @@ private Field TranslateField(TableFieldSchema field) return new Field(field.Name, TranslateType(field), field.Mode == "NULLABLE"); } - public override object? GetValue(IArrowArray arrowArray, int index) - { - switch (arrowArray) - { - case StructArray structArray: - return SerializeToJson(structArray, index); - case ListArray listArray: - return listArray.GetSlicedValues(index); - default: - return base.GetValue(arrowArray, index); - } - } - private IArrowType TranslateType(TableFieldSchema field) { // per https://developers.google.com/resources/api-libraries/documentation/bigquery/v2/java/latest/com/google/api/services/bigquery/model/TableFieldSchema.html#getType-- @@ -224,139 +259,6 @@ static IArrowReader ReadChunk(BigQueryReadClient readClient, string streamName) return options; } - private string SerializeToJson(StructArray structArray, int index) - { - Dictionary? jsonDictionary = ParseStructArray(structArray, index); - - return JsonSerializer.Serialize(jsonDictionary); - } - - private Dictionary? ParseStructArray(StructArray structArray, int index) - { - if (structArray.IsNull(index)) - return null; - - Dictionary jsonDictionary = new Dictionary(); - StructType structType = (StructType)structArray.Data.DataType; - for (int i = 0; i < structArray.Data.Children.Length; i++) - { - string name = structType.Fields[i].Name; - object? 
value = GetValue(structArray.Fields[i], index); - - if (value is StructArray structArray1) - { - List?> children = new List?>(); - - if (structArray1.Length > 1) - { - for (int j = 0; j < structArray1.Length; j++) - children.Add(ParseStructArray(structArray1, j)); - } - - if (children.Count > 0) - { - jsonDictionary.Add(name, children); - } - else - { - jsonDictionary.Add(name, ParseStructArray(structArray1, index)); - } - } - else if (value is IArrowArray arrowArray) - { - IList? values = CreateList(arrowArray); - - if (values != null) - { - for (int j = 0; j < arrowArray.Length; j++) - { - values.Add(base.GetValue(arrowArray, j)); - } - - jsonDictionary.Add(name, values); - } - else - { - jsonDictionary.Add(name, new List()); - } - } - else - { - jsonDictionary.Add(name, value); - } - } - - return jsonDictionary; - } - - private IList? CreateList(IArrowArray arrowArray) - { - if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray)); - - switch (arrowArray) - { - case BooleanArray booleanArray: - return new List(); - case Date32Array date32Array: - case Date64Array date64Array: - return new List(); - case Decimal128Array decimal128Array: - return new List(); - case Decimal256Array decimal256Array: - return new List(); - case DoubleArray doubleArray: - return new List(); - case FloatArray floatArray: - return new List(); -#if NET5_0_OR_GREATER - case PrimitiveArray halfFloatArray: - return new List(); -#endif - case Int8Array int8Array: - return new List(); - case Int16Array int16Array: - return new List(); - case Int32Array int32Array: - return new List(); - case Int64Array int64Array: - return new List(); - case StringArray stringArray: - return new List(); -#if NET6_0_OR_GREATER - case Time32Array time32Array: - case Time64Array time64Array: - return new List(); -#else - case Time32Array time32Array: - case Time64Array time64Array: - return new List(); -#endif - case TimestampArray timestampArray: - return new List(); - case UInt8Array uInt8Array: - return new List(); - case UInt16Array uInt16Array: - return new List(); - case UInt32Array uInt32Array: - return new List(); - case UInt64Array uInt64Array: - return new List(); - - case BinaryArray binaryArray: - return new List(); - - // not covered: - // -- struct array - // -- dictionary array - // -- fixed size binary - // -- list array - // -- union array - } - - return null; - } - - class MultiArrowReader : IArrowArrayStream { readonly Schema schema; diff --git a/csharp/src/Drivers/BigQuery/readme.md b/csharp/src/Drivers/BigQuery/readme.md index c9cad1909a..d78ab4c891 100644 --- a/csharp/src/Drivers/BigQuery/readme.md +++ b/csharp/src/Drivers/BigQuery/readme.md @@ -51,6 +51,12 @@ https://cloud.google.com/dotnet/docs/reference/Google.Cloud.BigQuery.V2/latest/G **adbc.bigquery.auth_json_credential**
    Required if using `service` authentication. This value is passed to the [GoogleCredential.FromJson](https://cloud.google.com/dotnet/docs/reference/Google.Apis/latest/Google.Apis.Auth.OAuth2.GoogleCredential#Google_Apis_Auth_OAuth2_GoogleCredential_FromJson_System_String) method. +**adbc.bigquery.get_query_results_options.timeout**
+    Optional. Sets the timeout (in minutes) for the GetQueryResultsOptions value. If not set, defaults to 5 minutes. + +**adbc.bigquery.max_fetch_concurrency**
+    Optional. Sets the [maxStreamCount](https://cloud.google.com/dotnet/docs/reference/Google.Cloud.BigQuery.Storage.V1/latest/Google.Cloud.BigQuery.Storage.V1.BigQueryReadClient#Google_Cloud_BigQuery_Storage_V1_BigQueryReadClient_CreateReadSession_System_String_Google_Cloud_BigQuery_Storage_V1_ReadSession_System_Int32_Google_Api_Gax_Grpc_CallSettings_) for the CreateReadSession method. If not set, defaults to 1. + **adbc.bigquery.include_constraints_getobjects**
    Optional. Some callers do not need the constraint details when they get the table information and can improve the speed of obtaining the results. Setting this value to `"false"` will not include the constraint details. The default value is `"true"`. diff --git a/csharp/src/Drivers/FlightSql/Apache.Arrow.Adbc.Drivers.FlightSql.csproj b/csharp/src/Drivers/FlightSql/Apache.Arrow.Adbc.Drivers.FlightSql.csproj index 64532c3fb8..ec718b6376 100644 --- a/csharp/src/Drivers/FlightSql/Apache.Arrow.Adbc.Drivers.FlightSql.csproj +++ b/csharp/src/Drivers/FlightSql/Apache.Arrow.Adbc.Drivers.FlightSql.csproj @@ -3,8 +3,8 @@ netstandard2.0;net6.0 - - + + diff --git a/csharp/src/Drivers/Interop/FlightSql/Apache.Arrow.Adbc.Drivers.Interop.FlightSql.csproj b/csharp/src/Drivers/Interop/FlightSql/Apache.Arrow.Adbc.Drivers.Interop.FlightSql.csproj new file mode 100644 index 0000000000..773062adba --- /dev/null +++ b/csharp/src/Drivers/Interop/FlightSql/Apache.Arrow.Adbc.Drivers.Interop.FlightSql.csproj @@ -0,0 +1,66 @@ + + + netstandard2.0;net472;net6.0 + readme.md + + false + + + + + + + + + + + + + + + + true + / + PreserveNewest + + + true + runtimes/win-x64/native + PreserveNewest + + + + + + + true + / + PreserveNewest + + + + + true + runtimes/win-x64/native + PreserveNewest + + + + + true + runtimes/linux-x64/native + PreserveNewest + + + + + true + runtimes/osx-x64/native + PreserveNewest + + + + + + + diff --git a/csharp/src/Drivers/Interop/FlightSql/Build-FlightSqlDriver.ps1 b/csharp/src/Drivers/Interop/FlightSql/Build-FlightSqlDriver.ps1 new file mode 100644 index 0000000000..fd7dee0690 --- /dev/null +++ b/csharp/src/Drivers/Interop/FlightSql/Build-FlightSqlDriver.ps1 @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +Write-Host "Building the Flight SQL ADBC Go driver" +Write-Host "IsPackagingPipeline=$Env:IsPackagingPipeline" + +if (-not (Test-Path env:IsPackagingPipeline)) { + Write-Host "IsPackagingPipeline environment variable does not exist." + exit +} + +# Get the value of the IsPackagingPipeline environment variable +$IsPackagingPipelineValue = $env:IsPackagingPipeline + +# Check if the value is "true" +if ($IsPackagingPipelineValue -ne "true") { + Write-Host "IsPackagingPipeline is not set to 'true'. Exiting the script." + exit +} + +$location = Get-Location + +$file = "libadbc_driver_flightsql.dll" + +if(Test-Path $file) +{ + exit +} + +cd ..\..\..\..\..\go\adbc\pkg + +make $file + +if(Test-Path $file) +{ + $processes = Get-Process | Where-Object { $_.Modules.ModuleName -contains $file } + + if ($processes.Count -eq 0) { + try { + # File is not being used, copy it to the destination + Copy-Item -Path $file -Destination $location + Write-Host "File copied successfully." 
+ } + catch { + Write-Host "Caught error: $_" + } + } else { + Write-Host "File is being used by another process. Cannot copy." + } +} diff --git a/csharp/src/Drivers/Interop/FlightSql/FlightSqlDriverLoader.cs b/csharp/src/Drivers/Interop/FlightSql/FlightSqlDriverLoader.cs new file mode 100644 index 0000000000..12ea491d92 --- /dev/null +++ b/csharp/src/Drivers/Interop/FlightSql/FlightSqlDriverLoader.cs @@ -0,0 +1,39 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +namespace Apache.Arrow.Adbc.Drivers.Interop.FlightSql +{ + /// + /// Lightweight class for loading the Flight SQL Go driver to .NET. + /// + public class FlightSqlDriverLoader : AdbcDriverLoader + { + public FlightSqlDriverLoader() : base("libadbc_driver_flightsql", "FlightSqlDriverInit") + { + } + + /// + /// Loads the Flight SQL Go driver from the current directory using the default name and entry point. + /// + /// An based on the Flight SQL Go driver. + /// + public static AdbcDriver LoadDriver() + { + return new FlightSqlDriverLoader().FindAndLoadDriver(); + } + } +} diff --git a/csharp/src/Drivers/Interop/FlightSql/copyFlightSqlDriver.sh b/csharp/src/Drivers/Interop/FlightSql/copyFlightSqlDriver.sh new file mode 100644 index 0000000000..ea931fa6d5 --- /dev/null +++ b/csharp/src/Drivers/Interop/FlightSql/copyFlightSqlDriver.sh @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +### copies the Flight SQL binaries for all platforms to be packaged for NuGet + +echo "Copying the FlightSQL ADBC Go driver" +echo "IsPackagingPipeline=$IsPackagingPipeline" + +if [[ -z "${IsPackagingPipeline}" ]]; then + echo "IsPackagingPipeline environment variable does not exist." + exit 0 +fi + +# Get the value of the IsPackagingPipeline environment variable +IsPackagingPipelineValue="${IsPackagingPipeline}" + +# Check if the value is "true" +if [[ "${IsPackagingPipelineValue}" != "true" ]]; then + echo "IsPackagingPipeline is not set to 'true'. Exiting the script."
+ exit 0 +fi + +destination_dir=$(pwd) + +file="libadbc_driver_flightsql.*" + +if ls libadbc_driver_flightsql.* 1> /dev/null 2>&1; then + echo "Files found. Exiting the script." + exit 0 +else + cd ../../../../../go/adbc/pkg + + source_dir=$(pwd) + + files_to_copy=$(find "$source_dir" -type f -name "$file") + + for file in $files_to_copy; do + cp "$file" "$destination_dir" + echo "Copied $file to $destination_dir" + done +fi diff --git a/csharp/src/Drivers/Interop/FlightSql/readme.md b/csharp/src/Drivers/Interop/FlightSql/readme.md new file mode 100644 index 0000000000..21e4c39e57 --- /dev/null +++ b/csharp/src/Drivers/Interop/FlightSql/readme.md @@ -0,0 +1,30 @@ + + +# About +This project generates a NuGet package containing the Flight SQL ADBC Go Driver for use in other .NET projects. It contains a lightweight loader for loading the driver. + +For details, see: + +[Flight SQL Driver](https://arrow.apache.org/adbc/main/driver/flight_sql.html) for docs + +[GitHub](https://github.com/apache/arrow-adbc/tree/main/go/adbc/driver/flightsql) for source code + +## Build the Flight SQL Driver +Run the `Build-FlightSqlDriver.ps1` script to build the Flight SQL driver. diff --git a/csharp/src/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Drivers.Interop.Snowflake.csproj b/csharp/src/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Drivers.Interop.Snowflake.csproj index cb810a99be..b71e9b6619 100644 --- a/csharp/src/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Drivers.Interop.Snowflake.csproj +++ b/csharp/src/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Drivers.Interop.Snowflake.csproj @@ -8,12 +8,12 @@ - + - + diff --git a/csharp/src/Drivers/Interop/Snowflake/SnowflakeDriverLoader.cs b/csharp/src/Drivers/Interop/Snowflake/SnowflakeDriverLoader.cs index 5229f64582..14d85b6de0 100644 --- a/csharp/src/Drivers/Interop/Snowflake/SnowflakeDriverLoader.cs +++ b/csharp/src/Drivers/Interop/Snowflake/SnowflakeDriverLoader.cs @@ -16,65 +16,27 @@ */ using System.IO; -using System.Runtime.InteropServices; -using Apache.Arrow.Adbc.C; namespace Apache.Arrow.Adbc.Drivers.Interop.Snowflake { /// /// Lightweight class for loading the Snowflake Go driver to .NET. /// - public class SnowflakeDriverLoader + public class SnowflakeDriverLoader : AdbcDriverLoader { - /// - /// Loads the Snowflake Go driver from the current directory using the default name and entry point. - /// - /// An based on the Snowflake Go driver. 
- /// - public static AdbcDriver LoadDriver() + public SnowflakeDriverLoader() : base("libadbc_driver_snowflake", "SnowflakeDriverInit") { - string root = "runtimes"; - string native = "native"; - string fileName = $"libadbc_driver_snowflake"; - string file; - - // matches extensions in https://github.com/apache/arrow-adbc/blob/main/go/adbc/pkg/Makefile - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - file = Path.Combine(root, $"linux-{GetArchitecture()}", native, $"{fileName}.so"); - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - file = Path.Combine(root, $"win-{GetArchitecture()}", native, $"{fileName}.dll"); - else - file = Path.Combine(root, $"osx-{GetArchitecture()}", native, $"{fileName}.dylib"); - if (File.Exists(file)) - { - // get the full path because some .NET versions need it - file = Path.GetFullPath(file); - } - else - { - throw new FileNotFoundException($"Cound not find {file}"); - } - - return LoadDriver(file, "SnowflakeDriverInit"); - } - - private static string GetArchitecture() - { - return RuntimeInformation.OSArchitecture.ToString().ToLower(); } /// /// Loads the Snowflake Go driver from the current directory using the default name and entry point. /// - /// The file to load. - /// The entry point of the file. /// An based on the Snowflake Go driver. - public static AdbcDriver LoadDriver(string file, string entryPoint) + /// + public static AdbcDriver LoadDriver() { - AdbcDriver snowflakeDriver = CAdbcDriverImporter.Load(file, entryPoint); - - return snowflakeDriver; + return new SnowflakeDriverLoader().FindAndLoadDriver(); } } } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/AdbcTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/AdbcTests.cs index 9577666431..3b55387b1d 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/AdbcTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/AdbcTests.cs @@ -207,7 +207,7 @@ private string GetPathForAdbcH() { // find the adbc.h file from the repo - string path = Path.Combine(new string[] { "..", "..", "..", "..", "..", "adbc.h" }); + string path = Path.Combine(new string[] { "..", "..", "..", "..", "..", "c", "include", "arrow-adbc", "adbc.h"}); Assert.True(File.Exists(path)); diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj b/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj index 519b6c3f5b..3b6eac7e2d 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj @@ -9,15 +9,15 @@ - - - - - + + + + + all runtime; build; native; contentfiles; analyzers - + diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Client/ClientTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Client/ClientTests.cs index 61efd0613b..e1d5a91781 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/Client/ClientTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Client/ClientTests.cs @@ -17,7 +17,9 @@ using System; using System.Collections.Generic; +using System.ComponentModel; using System.Data.SqlTypes; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Client; @@ -25,6 +27,7 @@ using Apache.Arrow.Types; using Moq; using Xunit; +using AdbcClient = Apache.Arrow.Adbc.Client; namespace Apache.Arrow.Adbc.Tests.Client { @@ -43,6 +46,56 @@ public void TestDecimalValues(DecimalBehavior decimalBehavior, string value, int Assert.True(rdrValue.GetType().Equals(expectedType)); } + /// + /// Demonstrates the OnGetValue method of an AdbcDataReader. 
+ /// + /// True/False to treat integers as strings. + [Theory] + [InlineData(true)] + [InlineData(false)] + public void TestOnGetValue(bool treatIntegersAsStrings) + { + AdbcDataReader rdr = GetMoqDataReaderForIntegers(); + + if (treatIntegersAsStrings) + { + rdr.OnGetValue += (o, e) => + { + if (o != null) + { + Int32Array? ints = o as Int32Array; + + if (ints != null) + { + int? value = ints.GetValue(e); + + if (value.HasValue) + return value.Value.ToString(); + else + return string.Empty; + } + } + + return string.Empty; + }; + } + + while (rdr.Read()) + { + object? rdrValue = rdr.GetValue(0); + + if (treatIntegersAsStrings) + { + Assert.True(rdrValue.GetType().Equals(typeof(string))); + } + else + { + if (rdrValue != DBNull.Value) + Assert.True(rdrValue.GetType().Equals(typeof(int))); + } + } + } + private AdbcDataReader GetMoqDataReader(DecimalBehavior decimalBehavior, string value, int precision, int scale) { SqlDecimal sqlDecimal = SqlDecimal.Parse(value); @@ -63,7 +116,7 @@ private AdbcDataReader GetMoqDataReader(DecimalBehavior decimalBehavior, string List records = new List() { - new RecordBatch(schema, values, values.Count) + new RecordBatch(schema, values, array.Length) }; MockArrayStream mockArrayStream = new MockArrayStream(schema, records); @@ -71,7 +124,6 @@ private AdbcDataReader GetMoqDataReader(DecimalBehavior decimalBehavior, string Mock mockStatement = new Mock(); mockStatement.Setup(x => x.ExecuteQuery()).Returns(queryResult); ; - mockStatement.Setup(x => x.GetValue(It.IsAny(), It.IsAny())).Returns(sqlDecimal); Adbc.Client.AdbcConnection mockConnection = new Adbc.Client.AdbcConnection(); mockConnection.DecimalBehavior = decimalBehavior; @@ -81,6 +133,152 @@ private AdbcDataReader GetMoqDataReader(DecimalBehavior decimalBehavior, string AdbcDataReader reader = cmd.ExecuteReader(); return reader; } + + private AdbcDataReader GetMoqDataReaderForIntegers() + { + List> metadata = new List>(); + List fields = new List(); + fields.Add(new Field("TestIntegers", new Int32Type(), true, metadata)); + + Schema schema = new Schema(fields, metadata); + Int32Array.Builder numbersBuilder = new Int32Array.Builder(); + numbersBuilder.AppendRange(new List() { 1, 2, 3 }); + numbersBuilder.AppendNull(); //null for #4 + numbersBuilder.Append(5); + + Int32Array numbersArray = numbersBuilder.Build(); + + List values = new List() { numbersArray }; + + List records = new List() + { + new RecordBatch(schema, values, numbersArray.Length) + }; + + MockArrayStream mockArrayStream = new MockArrayStream(schema, records); + QueryResult queryResult = new QueryResult(1, mockArrayStream); + + Mock mockStatement = new Mock(); + mockStatement.Setup(x => x.ExecuteQuery()).Returns(queryResult); ; + + Adbc.Client.AdbcConnection mockConnection = new Adbc.Client.AdbcConnection(); + + AdbcCommand cmd = new AdbcCommand(mockStatement.Object, mockConnection); + + AdbcDataReader reader = cmd.ExecuteReader(); + return reader; + } + + [Theory] + [InlineData("(adbc.driver.value, 1, s)", "adbc.driver.value", 1, "s", true)] + [InlineData("(somevalue,10, ms)", "somevalue", 10, "ms", true)] + [InlineData("(somevalue,10, s)", "somevalue", 10, "s", true)] + [InlineData("somevalue,10, s)", null, null, null, false)] + [InlineData("(somevalue,10, s", null, null, null, false)] + [InlineData("(some.value_goes.here,99,Q)", null, null, null, false)] + [InlineData("some.value_goes.here,99,Q", null, null, null, false)] + public void TestTimeoutParsing(string value, string? driverPropertyName, int? timeout, string? 
unit, bool success) + { + if (!success) + { + try + { + ConnectionStringParser.ParseTimeoutValue(value); + } + catch (ArgumentOutOfRangeException) { } + catch (InvalidOperationException) { } + catch + { + Assert.Fail("Unknown exception found"); + } + } + else + { + Assert.True(driverPropertyName != null); + Assert.True(timeout != null); + Assert.True(unit != null); + + TimeoutValue timeoutValue = ConnectionStringParser.ParseTimeoutValue(value); + + Assert.Equal(driverPropertyName, timeoutValue.DriverPropertyName); + Assert.Equal(timeout, timeoutValue.Value); + Assert.Equal(unit, timeoutValue.Units); + } + } + + [Theory] + [ClassData(typeof(ConnectionParsingTestData))] + internal void TestConnectionStringParsing(ConnectionStringExample connectionStringExample) + { + AdbcClient.AdbcConnection cn = new AdbcClient.AdbcConnection(connectionStringExample.ConnectionString); + + Mock mockStatement = new Mock(); + AdbcCommand cmd = new AdbcCommand(mockStatement.Object, cn); + + Assert.True(cn.StructBehavior == connectionStringExample.ExpectedStructBehavior); + Assert.True(cn.DecimalBehavior == connectionStringExample.ExpectedDecimalBehavior); + Assert.True(cn.ConnectionTimeout == connectionStringExample.ConnectionTimeout); + + if (!string.IsNullOrEmpty(connectionStringExample.CommandTimeoutProperty)) + { + Assert.True(cmd.AdbcCommandTimeoutProperty == connectionStringExample.CommandTimeoutProperty); + Assert.True(cmd.CommandTimeout == connectionStringExample.CommandTimeout); + } + else + { + Assert.Throws(() => cmd.AdbcCommandTimeoutProperty); + } + } + } + + internal class ConnectionStringExample + { + public ConnectionStringExample( + string connectionString, + DecimalBehavior decimalBehavior, + StructBehavior structBehavior, + string connectionTimeoutPropertyName, + int connectionTimeout, + string commandTimeoutPropertyName, + int commandTimeout) + { + ConnectionString = connectionString; + ExpectedDecimalBehavior = decimalBehavior; + ExpectedStructBehavior = structBehavior; + ConnectionTimeoutProperty = connectionTimeoutPropertyName; + ConnectionTimeout = connectionTimeout; + CommandTimeoutProperty = commandTimeoutPropertyName; + CommandTimeout = commandTimeout; + } + + public string ConnectionString { get; } + + public string ConnectionTimeoutProperty { get; } + + public int ConnectionTimeout { get; } + + public DecimalBehavior ExpectedDecimalBehavior { get; } + + public StructBehavior ExpectedStructBehavior { get; } + + public string CommandTimeoutProperty { get; } + + public int CommandTimeout { get; } + } + + /// + /// Collection of for testing statement timeouts."/> + /// + internal class ConnectionParsingTestData : TheoryData + { + public ConnectionParsingTestData() + { + int defaultDbConnectionTimeout = 15; + + Add(new("StructBehavior=JsonString", default, StructBehavior.JsonString, "", defaultDbConnectionTimeout, "", 30)); + Add(new("StructBehavior=JsonString;AdbcCommandTimeout=(adbc.apache.statement.query_timeout_s,45,s)", default, StructBehavior.JsonString, "", defaultDbConnectionTimeout, "adbc.apache.statement.query_timeout_s", 45)); + Add(new("StructBehavior=JsonString;DecimalBehavior=OverflowDecimalAsString;AdbcConnectionTimeout=(adbc.spark.connect_timeout_ms,90,s);AdbcCommandTimeout=(adbc.apache.statement.query_timeout_s,45,s)", DecimalBehavior.OverflowDecimalAsString, StructBehavior.JsonString, "adbc.spark.connect_timeout_ms", 90, "adbc.apache.statement.query_timeout_s", 45)); + } } class MockArrayStream : IArrowArrayStream diff --git 
a/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs index 1d947062dd..3deac62831 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs @@ -40,10 +40,12 @@ public class ClientTests /// The to use /// The queries to run /// The expected results (one per query) - public static void CanClientExecuteUpdate(Adbc.Client.AdbcConnection adbcConnection, + public static void CanClientExecuteUpdate( + Adbc.Client.AdbcConnection adbcConnection, TestConfiguration testConfiguration, string[] queries, - List expectedResults) + IReadOnlyList expectedResults, + string? environmentName = null) { if (adbcConnection == null) throw new ArgumentNullException(nameof(adbcConnection)); if (testConfiguration == null) throw new ArgumentNullException(nameof(testConfiguration)); @@ -63,7 +65,7 @@ public static void CanClientExecuteUpdate(Adbc.Client.AdbcConnection adbcConnect int rows = adbcCommand.ExecuteNonQuery(); - Assert.Equal(expectedResults[i], rows); + Assert.True(expectedResults[i]==rows, Utils.FormatMessage("Expected results are not equal", environmentName)); } } @@ -72,20 +74,27 @@ public static void CanClientExecuteUpdate(Adbc.Client.AdbcConnection adbcConnect /// /// The to use. /// The to use - public static void CanClientGetSchema(Adbc.Client.AdbcConnection adbcConnection, TestConfiguration testConfiguration) + /// The custom query to use instead of query from "/> + /// The custom column count to use instead of query from + public static void CanClientGetSchema( + Adbc.Client.AdbcConnection adbcConnection, + TestConfiguration testConfiguration, + string? customQuery = default, + int? expectedColumnCount = default, + string? environmentName = null) { if (adbcConnection == null) throw new ArgumentNullException(nameof(adbcConnection)); if (testConfiguration == null) throw new ArgumentNullException(nameof(testConfiguration)); adbcConnection.Open(); - using AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); + using AdbcCommand adbcCommand = new AdbcCommand(customQuery ?? testConfiguration.Query, adbcConnection); using AdbcDataReader reader = adbcCommand.ExecuteReader(CommandBehavior.SchemaOnly); DataTable? table = reader.GetSchemaTable(); // there is one row per field - Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, table?.Rows.Count); + Assert.Equal(expectedColumnCount ?? testConfiguration.Metadata.ExpectedColumnCount, table?.Rows.Count); } /// @@ -94,7 +103,14 @@ public static void CanClientGetSchema(Adbc.Client.AdbcConnection adbcConnection, /// /// The to use. /// The to use - public static void CanClientExecuteQuery(Adbc.Client.AdbcConnection adbcConnection, TestConfiguration testConfiguration) + /// Allows additional options to be set on the command before execution + public static void CanClientExecuteQuery( + Adbc.Client.AdbcConnection adbcConnection, + TestConfiguration testConfiguration, + Action? additionalCommandOptionsSetter = null, + string? customQuery = default, + int? expectedResultsCount = default, + string? 
environmentName = null) { if (adbcConnection == null) throw new ArgumentNullException(nameof(adbcConnection)); if (testConfiguration == null) throw new ArgumentNullException(nameof(testConfiguration)); @@ -103,7 +119,8 @@ public static void CanClientExecuteQuery(Adbc.Client.AdbcConnection adbcConnecti adbcConnection.Open(); - using AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); + using AdbcCommand adbcCommand = new AdbcCommand(customQuery ?? testConfiguration.Query, adbcConnection); + additionalCommandOptionsSetter?.Invoke(adbcCommand); using AdbcDataReader reader = adbcCommand.ExecuteReader(); try @@ -126,7 +143,7 @@ public static void CanClientExecuteQuery(Adbc.Client.AdbcConnection adbcConnecti } finally { reader.Close(); } - Assert.Equal(testConfiguration.ExpectedResultsCount, count); + Assert.Equal(expectedResultsCount ?? testConfiguration.ExpectedResultsCount, count); } /// @@ -135,7 +152,10 @@ public static void CanClientExecuteQuery(Adbc.Client.AdbcConnection adbcConnecti /// /// The to use. /// The to use - public static void VerifyTypesAndValues(Adbc.Client.AdbcConnection adbcConnection, SampleDataBuilder sampleDataBuilder) + public static void VerifyTypesAndValues( + Adbc.Client.AdbcConnection adbcConnection, + SampleDataBuilder sampleDataBuilder, + string? environmentName = null) { if (adbcConnection == null) throw new ArgumentNullException(nameof(adbcConnection)); if (sampleDataBuilder == null) throw new ArgumentNullException(nameof(sampleDataBuilder)); @@ -144,6 +164,13 @@ public static void VerifyTypesAndValues(Adbc.Client.AdbcConnection adbcConnectio foreach (SampleData sample in sampleDataBuilder.Samples) { + foreach (string preQueryCommandText in sample.PreQueryCommands) + { + using AdbcCommand preQueryCommand = adbcConnection.CreateCommand(); + preQueryCommand.CommandText = preQueryCommandText; + preQueryCommand.ExecuteNonQuery(); + } + using AdbcCommand dbCommand = adbcConnection.CreateCommand(); dbCommand.CommandText = sample.Query; @@ -152,18 +179,25 @@ public static void VerifyTypesAndValues(Adbc.Client.AdbcConnection adbcConnectio { var column_schema = reader.GetColumnSchema(); DataTable? dataTable = reader.GetSchemaTable(); - Assert.NotNull(dataTable); + Assert.True(dataTable != null, Utils.FormatMessage("dataTable is null", environmentName) ); - Assert.True(reader.FieldCount == sample.ExpectedValues.Count, $"{sample.ExpectedValues.Count} fields were expected but {reader.FieldCount} fields were returned for the query [{sample.Query}]"); + Assert.True(reader.FieldCount == sample.ExpectedValues.Count, Utils.FormatMessage($"{sample.ExpectedValues.Count} fields were expected but {reader.FieldCount} fields were returned for the query [{sample.Query}]", environmentName)); for (int i = 0; i < reader.FieldCount; i++) { object value = reader.GetValue(i); ColumnNetTypeArrowTypeValue ctv = sample.ExpectedValues[i]; - AssertTypeAndValue(ctv, value, reader, column_schema, dataTable, sample.Query); + AssertTypeAndValue(ctv, value, reader, column_schema, dataTable, sample.Query, environmentName); } } + + foreach (string postQueryCommandText in sample.PostQueryCommands) + { + using AdbcCommand preQueryCommand = adbcConnection.CreateCommand(); + preQueryCommand.CommandText = postQueryCommandText; + preQueryCommand.ExecuteNonQuery(); + } } } @@ -181,7 +215,8 @@ static void AssertTypeAndValue( DbDataReader reader, ReadOnlyCollection column_schema, DataTable dataTable, - string query) + string query, + string? 
environmentName = null) { string name = ctv.Name; Type? clientArrowType = column_schema.Where(x => x.ColumnName == name).FirstOrDefault()?.DataType; @@ -201,18 +236,18 @@ static void AssertTypeAndValue( Type? netType = reader[name]?.GetType(); if (netType == typeof(DBNull)) netType = null; - Assert.True(clientArrowType == ctv.ExpectedNetType, $"{name} is {clientArrowType.Name} and not {ctv.ExpectedNetType.Name} in the column schema for query [{query}]"); + Assert.True(clientArrowType == ctv.ExpectedNetType, Utils.FormatMessage($"{name} is {clientArrowType.Name} and not {ctv.ExpectedNetType.Name} in the column schema for query [{query}]", environmentName)); - Assert.True(dataTableType == ctv.ExpectedNetType, $"{name} is {dataTableType.Name} and not {ctv.ExpectedNetType.Name} in the data table for query [{query}]"); + Assert.True(dataTableType == ctv.ExpectedNetType, Utils.FormatMessage($"{name} is {dataTableType.Name} and not {ctv.ExpectedNetType.Name} in the data table for query [{query}]", environmentName)); if (arrowType is null) - Assert.True(ctv.ExpectedArrowArrayType is null, $"{name} is null and not {ctv.ExpectedArrowArrayType!.Name} in the provider type for query [{query}]"); + Assert.True(ctv.ExpectedArrowArrayType is null, Utils.FormatMessage($"{name} is null and not {ctv.ExpectedArrowArrayType!.Name} in the provider type for query [{query}]", environmentName)); else - Assert.True(arrowType.GetType() == ctv.ExpectedArrowArrayType, $"{name} is {arrowType.Name} and not {ctv.ExpectedArrowArrayType.Name} in the provider type for query [{query}]"); + Assert.True(arrowType.GetType() == ctv.ExpectedArrowArrayType, Utils.FormatMessage($"{name} is {arrowType.Name} and not {ctv.ExpectedArrowArrayType.Name} in the provider type for query [{query}]", environmentName)); if (netType != null) { - Assert.True(netType == ctv.ExpectedNetType, $"{name} is {netType.Name} and not {ctv.ExpectedNetType.Name} in the reader for query [{query}]"); + Assert.True(netType == ctv.ExpectedNetType, Utils.FormatMessage($"{name} is {netType.Name} and not {ctv.ExpectedNetType.Name} in the reader for query [{query}]", environmentName)); } if (value != DBNull.Value) @@ -220,20 +255,20 @@ static void AssertTypeAndValue( var type = value.GetType(); if (type.BaseType?.Name.Contains("PrimitiveArray") == false) { - Assert.True(ctv.ExpectedNetType == type, $"Expected type does not match actual type for {ctv.Name} for query [{query}]"); + Assert.True(ctv.ExpectedNetType == type, Utils.FormatMessage($"Expected type does not match actual type for {ctv.Name} for query [{query}]", environmentName)); if (value is byte[] actualBytes) { byte[]? 
expectedBytes = ctv.ExpectedValue as byte[]; - Assert.True(expectedBytes != null && actualBytes.SequenceEqual(expectedBytes), $"byte[] values do not match expected values for {ctv.Name} for query [{query}]"); + Assert.True(expectedBytes != null && actualBytes.SequenceEqual(expectedBytes), Utils.FormatMessage($"byte[] values do not match expected values for {ctv.Name} for query [{query}]", environmentName)); } else if (ctv.ExpectedValue is null) { - Assert.True(value is null, $"Expected value [{ctv.ExpectedValue}] does not match actual value [{value}] for {ctv.Name} for query [{query}]"); + Assert.True(value is null, Utils.FormatMessage($"Expected value [{ctv.ExpectedValue}] does not match actual value [{value}] for {ctv.Name} for query [{query}]", environmentName)); } else { - Assert.True(ctv.ExpectedValue.Equals(value), $"Expected value [{ctv.ExpectedValue}] does not match actual value [{value}] for {ctv.Name} for query [{query}]"); + Assert.True(ctv.ExpectedValue.Equals(value), Utils.FormatMessage($"Expected value [{ctv.ExpectedValue}] does not match actual value [{value}] for {ctv.Name} for query [{query}]", environmentName)); } } else @@ -258,7 +293,7 @@ static void AssertTypeAndValue( if (i == j) { - Assert.True(expected.Equals(actual), $"Expected value does not match actual value for {ctv.Name} at {i} for query [{query}]"); + Assert.True(expected.Equals(actual), Utils.FormatMessage($"Expected value does not match actual value for {ctv.Name} at {i} for query [{query}]", environmentName)); } } } @@ -266,7 +301,7 @@ static void AssertTypeAndValue( } else { - Assert.True(ctv.ExpectedValue == null, $"The value for {ctv.Name} is null and but it's expected value is not null for query [{query}]"); + Assert.True(ctv.ExpectedValue == null, Utils.FormatMessage($"The value for {ctv.Name} is null and but it's expected value is not null for query [{query}]", environmentName)); } } } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs index b6205aea17..beb9a35e7a 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs @@ -35,7 +35,13 @@ public class DriverTests /// /// The number of records. /// - public static void CanExecuteQuery(QueryResult queryResult, long expectedNumberOfResults) + /// + /// Name of the test environment. + /// + public static void CanExecuteQuery( + QueryResult queryResult, + long expectedNumberOfResults, + string? 
environmentName = null) { long count = 0; @@ -46,14 +52,14 @@ public static void CanExecuteQuery(QueryResult queryResult, long expectedNumberO count += nextBatch.Length; } - Assert.True(expectedNumberOfResults == count, $"The parsed records ({count}) differ from the expected amount ({expectedNumberOfResults})"); + Assert.True(expectedNumberOfResults == count, Utils.FormatMessage($"The parsed records ({count}) differ from the expected amount ({expectedNumberOfResults})", environmentName)); // if the values were set, make sure they are correct if (queryResult.RowCount != -1) { - Assert.True(queryResult.RowCount == expectedNumberOfResults, "The RowCount value does not match the expected results"); + Assert.True(queryResult.RowCount == expectedNumberOfResults, Utils.FormatMessage("The RowCount value does not match the expected results", environmentName)); - Assert.True(queryResult.RowCount == count, "The RowCount value does not match the counted records"); + Assert.True(queryResult.RowCount == count, Utils.FormatMessage("The RowCount value does not match the counted records", environmentName)); } } @@ -67,25 +73,35 @@ public static void CanExecuteQuery(QueryResult queryResult, long expectedNumberO /// /// The number of records. /// - public static async Task CanExecuteQueryAsync(QueryResult queryResult, long expectedNumberOfResults) + /// + /// Name of the test environment. + /// + public static async Task CanExecuteQueryAsync( + QueryResult queryResult, + long expectedNumberOfResults, + string? environmentName = null) { long count = 0; while (queryResult.Stream != null) { RecordBatch nextBatch = await queryResult.Stream.ReadNextRecordBatchAsync(); + if (expectedNumberOfResults == 0) + { + Assert.Null(nextBatch); + } if (nextBatch == null) { break; } count += nextBatch.Length; } - Assert.True(expectedNumberOfResults == count, $"The parsed records ({count}) differ from the expected amount ({expectedNumberOfResults})"); + Assert.True(expectedNumberOfResults == count, Utils.FormatMessage($"The parsed records ({count}) differ from the expected amount ({expectedNumberOfResults})", environmentName)); // if the values were set, make sure they are correct if (queryResult.RowCount != -1) { - Assert.True(queryResult.RowCount == expectedNumberOfResults, "The RowCount value does not match the expected results"); + Assert.True(queryResult.RowCount == expectedNumberOfResults, Utils.FormatMessage("The RowCount value does not match the expected results", environmentName)); - Assert.True(queryResult.RowCount == count, "The RowCount value does not match the counted records"); + Assert.True(queryResult.RowCount == count, Utils.FormatMessage("The RowCount value does not match the counted records", environmentName)); } } } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/GetObjectsParser.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/GetObjectsParser.cs index 88c4dac2d0..a4e1456dc5 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/GetObjectsParser.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Metadata/GetObjectsParser.cs @@ -30,10 +30,9 @@ public class GetObjectsParser /// Parses a from a GetObjects call for the . /// /// - /// /// /// - public static List ParseCatalog(RecordBatch recordBatch, string? databaseName, string? schemaName) + public static List ParseCatalog(RecordBatch recordBatch, string? 
schemaName) { StringArray catalogNameArray = (StringArray)recordBatch.Column("catalog_name"); ListArray dbSchemaArray = (ListArray)recordBatch.Column("catalog_db_schemas"); diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/SampleDataBuilder.cs b/csharp/test/Apache.Arrow.Adbc.Tests/SampleDataBuilder.cs index 74baf16b83..08375ed425 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/SampleDataBuilder.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/SampleDataBuilder.cs @@ -47,6 +47,17 @@ public SampleData() /// public string Query { get; set; } = string.Empty; + /// + /// A set of commands to run before the query. + /// + public List PreQueryCommands { get; set; } = new List(); + + /// + /// A set of commands to run after the query. + /// + public List PostQueryCommands { get; set; } = new List(); + + /// /// The expected values. /// diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs b/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs index 92271678f0..8bbd2df116 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs @@ -32,21 +32,30 @@ namespace Apache.Arrow.Adbc.Tests /// /// Provides a base class for ADBC tests. /// - /// A TestConfiguration type to use when accessing test configuration files. - public abstract class TestBase : IDisposable where T : TestConfiguration + /// A TestConfiguration type to use when accessing test configuration files. + public abstract class TestBase : IDisposable + where TConfig : TestConfiguration + where TEnv : TestEnvironment { private bool _disposedValue; - private T? _testConfiguration; - private AdbcConnection? _connection = null; - private AdbcStatement? _statement = null; + private readonly Lazy _testConfiguration; + private readonly Lazy _connection; + private readonly Lazy _statement; + private readonly TestEnvironment.Factory _testEnvFactory; + private readonly Lazy _testEnvironment; /// /// Constructs a new TestBase object with an output helper. /// /// Test output helper for writing test output. - public TestBase(ITestOutputHelper? outputHelper) + public TestBase(ITestOutputHelper? outputHelper, TestEnvironment.Factory testEnvFactory) { OutputHelper = outputHelper; + _testEnvFactory = testEnvFactory; + _testEnvironment = new Lazy(() => _testEnvFactory.Create(() => Connection)); + _testConfiguration = new Lazy(() => Utils.LoadTestConfiguration(TestConfigVariable)); + _connection = new Lazy(() => NewConnection()); + _statement = new Lazy(() => Connection.CreateStatement()); } /// @@ -54,10 +63,17 @@ public TestBase(ITestOutputHelper? outputHelper) /// protected ITestOutputHelper? OutputHelper { get; } + public TEnv TestEnvironment => _testEnvironment.Value; + /// /// The name of the environment variable that stores the full location of the test configuration file. /// - protected abstract string TestConfigVariable { get; } + protected string TestConfigVariable => TestEnvironment.TestConfigVariable; + + protected string VendorVersion => TestEnvironment.VendorVersion; + + protected Version VendorVersionAsVersion => new Lazy(() => new Version(VendorVersion)).Value; + /// /// Creates a temporary table (if possible) using the native SQL dialect. @@ -65,10 +81,10 @@ public TestBase(ITestOutputHelper? outputHelper) /// The ADBC statement to apply the update. /// The columns definition in the native SQL dialect. /// A disposable temporary table object that will drop the table when disposed. 
- protected virtual async ValueTask NewTemporaryTableAsync(AdbcStatement statement, string columns) + protected async Task NewTemporaryTableAsync(AdbcStatement statement, string columns) { string tableName = NewTableName(); - string sqlUpdate = string.Format("CREATE TEMPORARY IF NOT EXISTS TABLE {0} ({1})", tableName, columns); + string sqlUpdate = TestEnvironment.GetCreateTemporaryTableStatement(tableName, columns); return await TemporaryTable.NewTemporaryTableAsync(statement, tableName, sqlUpdate); } @@ -76,74 +92,52 @@ protected virtual async ValueTask NewTemporaryTableAsync(AdbcSta /// Creates a new unique table name . /// /// A unique table name. - protected virtual string NewTableName() => string.Format( - "{0}.{1}.{2}", - TestConfiguration.Metadata.Catalog, - TestConfiguration.Metadata.Schema, - Guid.NewGuid().ToString().Replace("-", "") + protected string NewTableName() => string.Format( + "{0}{1}{2}", + string.IsNullOrEmpty(TestConfiguration.Metadata.Catalog) ? string.Empty : DelimitIdentifier(TestConfiguration.Metadata.Catalog) + ".", + string.IsNullOrEmpty(TestConfiguration.Metadata.Schema) ? string.Empty : DelimitIdentifier(TestConfiguration.Metadata.Schema) + ".", + DelimitIdentifier(Guid.NewGuid().ToString().Replace("-", "")) ); /// /// Gets the relative resource location of source SQL data used in driver testing. /// - protected abstract string SqlDataResourceLocation { get; } + protected string SqlDataResourceLocation => TestEnvironment.SqlDataResourceLocation; + + protected int ExpectedColumnCount => TestEnvironment.ExpectedColumnCount; /// /// Creates a new driver. /// - protected abstract AdbcDriver NewDriver { get; } + protected AdbcDriver NewDriver => TestEnvironment.CreateNewDriver(); + + protected bool SupportsDelete => TestEnvironment.SupportsDelete; + + protected bool SupportsUpdate => TestEnvironment.SupportsUpdate; + + protected bool ValidateAffectedRows => TestEnvironment.ValidateAffectedRows; /// /// Gets the parameters from the test configuration that are passed to the driver as a dictionary. /// /// The test configuration as input. /// Ditionary of parameters for the driver. - protected abstract Dictionary GetDriverParameters(T testConfiguration); + protected virtual Dictionary GetDriverParameters(TConfig testConfiguration) => TestEnvironment.GetDriverParameters(testConfiguration); /// /// Gets a single ADBC Connection for the object. /// - protected AdbcConnection Connection - { - get - { - if (_connection == null) - { - _connection = NewConnection(); - } - return _connection; - } - } + protected AdbcConnection Connection => _connection.Value; /// /// Gets as single ADBC Statement for the object. /// - protected AdbcStatement Statement - { - get - { - if (_statement == null) - { - _statement = Connection.CreateStatement(); - } - return _statement; - } - } + protected AdbcStatement Statement => _statement.Value; /// /// Gets the test configuration file. 
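Because the temporary-table helper returns a disposable wrapper, a derived test can scope its DDL with a using block. A brief sketch, assuming a test class that derives from TestBase; the column definition is a placeholder and not taken from any driver-specific test:

```csharp
// Inside a class derived from TestBase<TConfig, TEnv>; "value_column INT" is a
// placeholder column list.
private async Task TemporaryTableRoundTripAsync()
{
    using (var temporaryTable = await NewTemporaryTableAsync(Statement, "value_column INT"))
    {
        // Statements executed here can target the generated table; it is
        // dropped automatically when the wrapper is disposed.
    }
}
```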
/// - protected T TestConfiguration - { - get - { - if (_testConfiguration == null) - { - _testConfiguration = Utils.LoadTestConfiguration(TestConfigVariable); - } - return _testConfiguration; - } - } + protected TConfig TestConfiguration => _testConfiguration.Value; /// /// Parses the queries from internal resource location @@ -159,7 +153,10 @@ protected string[] GetQueries() if (line.TrimStart().StartsWith("--")) { continue; } if (line.Contains(placeholder)) { - string modifiedLine = line.Replace(placeholder, $"{TestConfiguration.Metadata.Catalog}.{TestConfiguration.Metadata.Schema}.{TestConfiguration.Metadata.Table}"); + string table = TestConfiguration.Metadata.Table; + string catlog = !string.IsNullOrEmpty(TestConfiguration.Metadata.Catalog) ? TestConfiguration.Metadata.Catalog + "." : string.Empty; + string schema = !string.IsNullOrEmpty(TestConfiguration.Metadata.Schema) ? TestConfiguration.Metadata.Schema + "." : string.Empty; + string modifiedLine = line.Replace(placeholder, $"{catlog}{schema}{table}"); content.AppendLine(modifiedLine); } else @@ -173,13 +170,18 @@ protected string[] GetQueries() return queries; } + protected SampleDataBuilder GetSampleDataBuilder() + { + return TestEnvironment.GetSampleDataBuilder(); + } + /// /// Gets a the Spark ADBC driver with settings from the . /// /// /// /// - protected AdbcConnection NewConnection(T? testConfiguration = null, IReadOnlyDictionary? connectionOptions = null) + protected AdbcConnection NewConnection(TConfig? testConfiguration = null, IReadOnlyDictionary? connectionOptions = null) { Dictionary parameters = GetDriverParameters(testConfiguration ?? TestConfiguration); AdbcDatabase database = NewDriver.Open(parameters); @@ -200,7 +202,7 @@ protected async Task ValidateInsertSelectDeleteSingleValueAsync(string selectSta await InsertSingleValueAsync(tableName, columnName, formattedValue ?? value?.ToString()); await SelectAndValidateValuesAsync(selectStatement, value, 1); string whereClause = GetWhereClause(columnName, formattedValue ?? value); - await DeleteFromTableAsync(tableName, whereClause, 1); + if (SupportsDelete) await DeleteFromTableAsync(tableName, whereClause, 1); } /// @@ -216,22 +218,89 @@ protected async Task ValidateInsertSelectDeleteSingleValueAsync(string tableName await InsertSingleValueAsync(tableName, columnName, formattedValue ?? value?.ToString()); await SelectAndValidateValuesAsync(tableName, columnName, value, 1, formattedValue); string whereClause = GetWhereClause(columnName, formattedValue ?? value); - await DeleteFromTableAsync(tableName, whereClause, 1); + if (SupportsDelete) await DeleteFromTableAsync(tableName, whereClause, 1); + } + + /// + /// Validates that two inserts, select and delete statement works with the given value. + /// + /// The name of the table to use. + /// The name of the column. + /// The value to insert, select and delete. + /// The formated value to insert, select and delete. + /// + protected async Task ValidateInsertSelectDeleteTwoValuesAsync(string tableName, string columnName, object? value, string? formattedValue = null) + { + await InsertSingleValueAsync(tableName, columnName, formattedValue ?? value?.ToString()); + await InsertSingleValueAsync(tableName, columnName, formattedValue ?? value?.ToString()); + await SelectAndValidateValuesAsync(tableName, columnName, [value, value], 2, formattedValue); + string whereClause = GetWhereClause(columnName, formattedValue ?? 
value); + if (SupportsDelete) await DeleteFromTableAsync(tableName, whereClause, 2); } + /// + /// Validates "multi-value" scenarios + /// + /// + /// + /// + /// + /// + /// + /// + protected async Task ValidateInsertSelectDeleteMultipleValuesAsync(string tableName, string columnName, string indexColumnName, object?[] values, string?[]? formattedValues = null) + { + await InsertMultipleValuesWithIndexColumnAsync(tableName, columnName, indexColumnName, values, formattedValues); + + string selectStatement = $"SELECT {columnName}, {indexColumnName} FROM {tableName}"; + await SelectAndValidateValuesAsync(selectStatement, values, values.Length, hasIndexColumn: true); + + if (SupportsDelete) await DeleteFromTableAsync(tableName, "", values.Length); + } + + /// + /// Inserts multiple values + /// + /// Name of the table to perform the insert + /// Name of the value column to insert + /// Name of the index column to insert index values + /// The array of values to insert + /// The array of formatted values to insert + /// + protected async Task InsertMultipleValuesWithIndexColumnAsync(string tableName, string columnName, string indexColumnName, object?[] values, string?[]? formattedValues) + { + string insertStatement = GetInsertStatementWithIndexColumn(tableName, columnName, indexColumnName, values, formattedValues); + OutputHelper?.WriteLine(insertStatement); + Statement.SqlQuery = insertStatement; + UpdateResult updateResult = await Statement.ExecuteUpdateAsync(); + if (ValidateAffectedRows) Assert.Equal(values.Length, updateResult.AffectedRows); + } + + /// + /// Gets the SQL INSERT statement for inserting multiple values with an index column + /// + /// Name of the table to perform the insert + /// Name of the value column to insert + /// Name of the index column to insert index values + /// The array of values to insert + /// The array of formatted values to insert + /// + protected string GetInsertStatementWithIndexColumn(string tableName, string columnName, string indexColumnName, object?[] values, string?[]? formattedValues) => + TestEnvironment.GetInsertStatementWithIndexColumn(tableName, columnName, indexColumnName, values, formattedValues); + /// /// Inserts a single value into a table. /// /// The name of the table to use. /// The name of the column. /// The value to insert. - protected virtual async Task InsertSingleValueAsync(string tableName, string columnName, string? value) + protected async Task InsertSingleValueAsync(string tableName, string columnName, string? value) { - string insertNumberStatement = GetInsertValueStatement(tableName, columnName, value); - OutputHelper?.WriteLine(insertNumberStatement); - Statement.SqlQuery = insertNumberStatement; + string insertStatement = GetInsertStatement(tableName, columnName, value); + OutputHelper?.WriteLine(insertStatement); + Statement.SqlQuery = insertStatement; UpdateResult updateResult = await Statement.ExecuteUpdateAsync(); - Assert.Equal(1, updateResult.AffectedRows); + if (ValidateAffectedRows) Assert.Equal(1, updateResult.AffectedRows); } /// @@ -241,8 +310,8 @@ protected virtual async Task InsertSingleValueAsync(string tableName, string col /// The name of the column. /// The value to insert. /// - protected virtual string GetInsertValueStatement(string tableName, string columnName, string? value) => - string.Format("INSERT INTO {0} ({1}) VALUES ({2});", tableName, columnName, value ?? "NULL"); + protected string GetInsertStatement(string tableName, string columnName, string? 
value) => + TestEnvironment.GetInsertStatement(tableName, columnName, value); /// /// Deletes a (single) value from a table. @@ -250,13 +319,13 @@ protected virtual string GetInsertValueStatement(string tableName, string column /// The name of the table to use. /// The WHERE clause string. /// The expected number of affected rows. - protected virtual async Task DeleteFromTableAsync(string tableName, string whereClause, int expectedRowsAffected) + protected async Task DeleteFromTableAsync(string tableName, string whereClause, int expectedRowsAffected) { string deleteNumberStatement = GetDeleteValueStatement(tableName, whereClause); OutputHelper?.WriteLine(deleteNumberStatement); Statement.SqlQuery = deleteNumberStatement; UpdateResult updateResult = await Statement.ExecuteUpdateAsync(); - Assert.Equal(expectedRowsAffected, updateResult.AffectedRows); + if (ValidateAffectedRows) Assert.Equal(expectedRowsAffected, updateResult.AffectedRows); } /// @@ -265,8 +334,7 @@ protected virtual async Task DeleteFromTableAsync(string tableName, string where /// The name of the table to use. /// The WHERE clause string. /// - protected virtual string GetDeleteValueStatement(string tableName, string whereClause) => - string.Format("DELETE FROM {0} WHERE {1};", tableName, whereClause); + protected string GetDeleteValueStatement(string tableName, string whereClause) => TestEnvironment.GetDeleteValueStatement(tableName, whereClause); /// /// Selects a single value and validates it equality with expected value and number of results. @@ -276,12 +344,26 @@ protected virtual string GetDeleteValueStatement(string tableName, string whereC /// The value to select and validate. /// The number of expected results (rows). /// - protected virtual async Task SelectAndValidateValuesAsync(string table, string columnName, object? value, int expectedLength, string? formattedValue = null) + protected async Task SelectAndValidateValuesAsync(string table, string columnName, object? value, int expectedLength, string? formattedValue = null) { string selectNumberStatement = GetSelectSingleValueStatement(table, columnName, formattedValue ?? value); await SelectAndValidateValuesAsync(selectNumberStatement, value, expectedLength); } + /// + /// Selects a single value and validates it equality with expected value and number of results. + /// + /// The name of the table to use. + /// The name of the column. + /// The value to select and validate. + /// The number of expected results (rows). + /// + protected async Task SelectAndValidateValuesAsync(string table, string columnName, object?[] values, int expectedLength, string? formattedValue = null) + { + string selectNumberStatement = GetSelectSingleValueStatement(table, columnName, formattedValue ?? values[0]); + await SelectAndValidateValuesAsync(selectNumberStatement, values, expectedLength); + } + /// /// Selects a single value and validates it equality with expected value and number of results. /// @@ -289,7 +371,29 @@ protected virtual async Task SelectAndValidateValuesAsync(string table, string c /// The value to select and validate. /// The number of expected results (rows). /// - protected virtual async Task SelectAndValidateValuesAsync(string selectStatement, object? value, int expectedLength) + protected async Task SelectAndValidateValuesAsync(string selectStatement, object? value, int expectedLength) + { + await SelectAndValidateValuesAsync(selectStatement, [value], expectedLength); + } + + private static T? 
ArrowArrayAs(IArrowArray arrowArray) + where T : IArrowArray + { + if (arrowArray is T t) + { + return t; + } + return default; + } + + /// + /// Selects a single value and validates it equality with expected value and number of results. + /// + /// The SQL statement to execute. + /// The array of values to select and validate. + /// The number of expected results (rows). + /// + protected async Task SelectAndValidateValuesAsync(string selectStatement, object?[] values, int expectedLength, bool hasIndexColumn = false) { Statement.SqlQuery = selectStatement; OutputHelper?.WriteLine(selectStatement); @@ -297,8 +401,23 @@ protected virtual async Task SelectAndValidateValuesAsync(string selectStatement int actualLength = 0; using (IArrowArrayStream stream = queryResult.Stream ?? throw new InvalidOperationException("stream is null")) { + Dictionary> valueGetters = new() + { + { ArrowTypeId.Decimal128, (a, i) => ArrowArrayAs(a)?.GetSqlDecimal(i) }, + { ArrowTypeId.Double, (a, i) => ArrowArrayAs(a)?.GetValue(i) }, + { ArrowTypeId.Float, (a, i) => ArrowArrayAs(a)?.GetValue(i) }, + { ArrowTypeId.Int64, (a, i) => ArrowArrayAs(a)?.GetValue(i) }, + { ArrowTypeId.Int32, (a, i) => ArrowArrayAs(a)?.GetValue(i) }, + { ArrowTypeId.Int16, (a, i) => ArrowArrayAs(a)?.GetValue(i) }, + { ArrowTypeId.Int8, (a, i) => ArrowArrayAs(a)?.GetValue(i) }, + { ArrowTypeId.String, (a, i) => ArrowArrayAs(a)?.GetString(i) }, + { ArrowTypeId.Timestamp, (a, i) => ArrowArrayAs(a)?.GetTimestamp(i) }, + { ArrowTypeId.Date32, (a, i) => ArrowArrayAs(a)?.GetDateTimeOffset(i) }, + { ArrowTypeId.Boolean, (a, i) => ArrowArrayAs(a)?.GetValue(i) }, + }; // Assume first column Field field = stream.Schema.GetFieldByIndex(0); + Int32Array? indexArray = null; while (true) { using (RecordBatch nextBatch = await stream.ReadNextRecordBatchAsync()) @@ -306,73 +425,28 @@ protected virtual async Task SelectAndValidateValuesAsync(string selectStatement if (nextBatch == null) { break; } switch (field.DataType) { - case Decimal128Type: - Decimal128Array decimalArray = (Decimal128Array)nextBatch.Column(0); - actualLength += decimalArray.Length; - ValidateValue(value, decimalArray.Length, (i) => decimalArray.GetSqlDecimal(i)); - break; - case DoubleType: - DoubleArray doubleArray = (DoubleArray)nextBatch.Column(0); - actualLength += doubleArray.Length; - ValidateValue(value, doubleArray.Length, (i) => doubleArray.GetValue(i)); - break; - case FloatType: - FloatArray floatArray = (FloatArray)nextBatch.Column(0); - actualLength += floatArray.Length; - ValidateValue(value, floatArray.Length, (i) => floatArray.GetValue(i)); - break; - case Int64Type: - Int64Array int64Array = (Int64Array)nextBatch.Column(0); - actualLength += int64Array.Length; - ValidateValue(value, int64Array.Length, (i) => int64Array.GetValue(i)); - break; - case Int32Type: - Int32Array intArray = (Int32Array)nextBatch.Column(0); - actualLength += intArray.Length; - ValidateValue(value, intArray.Length, (i) => intArray.GetValue(i)); - break; - case Int16Type: - Int16Array shortArray = (Int16Array)nextBatch.Column(0); - actualLength += shortArray.Length; - ValidateValue(value, shortArray.Length, (i) => shortArray.GetValue(i)); - break; - case Int8Type: - Int8Array tinyIntArray = (Int8Array)nextBatch.Column(0); - actualLength += tinyIntArray.Length; - ValidateValue(value, tinyIntArray.Length, (i) => tinyIntArray.GetValue(i)); - break; - case StringType: - StringArray stringArray = (StringArray)nextBatch.Column(0); - actualLength += stringArray.Length; - ValidateValue(value, 
stringArray.Length, (i) => stringArray.GetString(i)); - break; - case TimestampType: - TimestampArray timestampArray = (TimestampArray)nextBatch.Column(0); - actualLength += timestampArray.Length; - ValidateValue(value, timestampArray.Length, (i) => timestampArray.GetTimestamp(i)); - break; - case Date32Type: - Date32Array date32Array = (Date32Array)nextBatch.Column(0); - actualLength += date32Array.Length; - ValidateValue(value, date32Array.Length, (i) => date32Array.GetDateTimeOffset(i)); - break; - case BooleanType: - BooleanArray booleanArray = (BooleanArray)nextBatch.Column(0); - actualLength += booleanArray.Length; - ValidateValue(value, booleanArray.Length, (i) => booleanArray.GetValue(i)); - break; case BinaryType: BinaryArray binaryArray = (BinaryArray)nextBatch.Column(0); actualLength += binaryArray.Length; - ValidateValue(value, binaryArray.Length, (i) => binaryArray.GetBytes(i).ToArray()); + ValidateBinaryArrayValue((i) => values?[i], binaryArray); break; case NullType: NullArray nullArray = (NullArray)nextBatch.Column(0); actualLength += nullArray.Length; - ValidateValue(value == null, nullArray.Length, (i) => nullArray.IsNull(i)); + ValidateValue(nullArray.Length, (i) => values?[i] == null, (i) => nullArray.IsNull(i)); break; default: - Assert.Fail($"Unhandled datatype {field.DataType}"); + if (valueGetters.TryGetValue(field.DataType.TypeId, out Func? valueGetter)) + { + IArrowArray array = nextBatch.Column(0); + actualLength += array.Length; + indexArray = hasIndexColumn ? (Int32Array)nextBatch.Column(1) : null; + ValidateValue(array.Length, (i) => values?[i], (i) => valueGetter(array, i), indexArray, array.IsNull); + } + else + { + Assert.Fail($"Unhandled datatype {field.DataType}"); + } break; } @@ -385,14 +459,39 @@ protected virtual async Task SelectAndValidateValuesAsync(string selectStatement /// /// Validates a single values for all results (in the batch). /// - /// The value to validate. + /// The value to validate. + /// The binary array to validate + private static void ValidateBinaryArrayValue(Func expectedValues, BinaryArray binaryArray) + { + for (int i = 0; i < binaryArray.Length; i++) + { + // Note: null is indicated in output flag 'isNull'. + byte[] byteArray = binaryArray.GetBytes(i, out bool isNull).ToArray(); + byte[]? nullableByteArray = isNull ? null : byteArray; + var expectedValue = expectedValues(i); + Assert.Equal(expectedValue, nullableByteArray); + } + } + + /// + /// Validates a single values for all results (in the batch). + /// /// The length of the current batch/array. + /// The value to validate. /// The getter function to retrieve the actual value. - private static void ValidateValue(object? value, int length, Func getter) + /// + private static void ValidateValue(int length, Func value, Func getter, Int32Array? indexColumn = null, Func? isNullEvaluator = default) { for (int i = 0; i < length; i++) { - Assert.Equal(value, getter(i)); + int valueIndex = indexColumn?.GetValue(i) ?? i; + object? expected = value(valueIndex); + if (isNullEvaluator != null) + { + Assert.Equal(expected == null, isNullEvaluator(i)); + } + object? actual = getter(i); + Assert.Equal(expected, actual); } } @@ -403,13 +502,13 @@ private static void ValidateValue(object? value, int length, Func /// The name of the column. /// The value to select and validate. /// The native SQL statement. - protected virtual string GetSelectSingleValueStatement(string table, string columnName, object? 
value) => - $"SELECT {columnName} FROM {table} WHERE {GetWhereClause(columnName, value)}"; + protected string GetSelectSingleValueStatement(string table, string columnName, object? value) => + $"SELECT {columnName} FROM {table} {GetWhereClause(columnName, value)}"; - protected virtual string GetWhereClause(string columnName, object? value) => + protected string GetWhereClause(string columnName, object? value) => value == null - ? $"{columnName} IS NULL" - : string.Format("{0} = {1}", columnName, MaybeDoubleToString(value)); + ? $"WHERE {columnName} IS NULL" + : string.Format("WHERE {0} = {1}", columnName, MaybeDoubleToString(value)); private static object MaybeDoubleToString(object value) => value.GetType() == typeof(float) @@ -428,11 +527,11 @@ protected static string ConvertDoubleToString(double value) switch (value) { case double.PositiveInfinity: - return "'inf'"; + return "double('infinity')"; case double.NegativeInfinity: - return "'-inf'"; + return "double('-infinity')"; case double.NaN: - return "'NaN'"; + return "double('NaN')"; #if NET472 // Double.ToString() rounds max/min values, causing Snowflake to store +/- infinity case double.MaxValue: @@ -455,11 +554,11 @@ protected static string ConvertFloatToString(float value) switch (value) { case float.PositiveInfinity: - return "'inf'"; + return "float('infinity')"; case float.NegativeInfinity: - return "'-inf'"; + return "float('-infinity')"; case float.NaN: - return "'NaN'"; + return "float('NaN')"; #if NET472 // Float.ToString() rounds max/min values, causing Snowflake to store +/- infinity case float.MaxValue: @@ -482,15 +581,13 @@ protected virtual void Dispose(bool disposing) { if (disposing) { - if (_statement != null) + if (_statement.IsValueCreated) { - _statement.Dispose(); - _statement = null; + _statement.Value.Dispose(); } - if (_connection != null) + if (_connection.IsValueCreated) { - _connection.Dispose(); - _connection = null; + _connection.Value.Dispose(); } } @@ -510,12 +607,12 @@ protected static string QuoteValue(string value) return $"'{value.Replace("'", "''")}'"; } - protected virtual string DelimitIdentifier(string value) + protected string DelimitIdentifier(string value) { return $"{Delimiter}{value.Replace(Delimiter, $"{Delimiter}{Delimiter}")}{Delimiter}"; } - protected virtual string Delimiter => "\""; + protected string Delimiter => TestEnvironment.Delimiter; protected static void AssertContainsAll(string[] expectedTexts, string value) { @@ -533,15 +630,18 @@ protected static void AssertContainsAll(string[] expectedTexts, string value) /// An enumeration of patterns to match produced from the identifier. protected static IEnumerable GetPatterns(string? name) { - if (string.IsNullOrEmpty(name)) yield break; + if (name == null) yield break; yield return new object[] { name! 
}; - yield return new object[] { $"{GetPartialNameForPatternMatch(name!)}%" }; - yield return new object[] { $"{GetPartialNameForPatternMatch(name!).ToLower()}%" }; - yield return new object[] { $"{GetPartialNameForPatternMatch(name!).ToUpper()}%" }; - yield return new object[] { $"_{GetNameWithoutFirstChatacter(name!)}" }; - yield return new object[] { $"_{GetNameWithoutFirstChatacter(name!).ToLower()}" }; - yield return new object[] { $"_{GetNameWithoutFirstChatacter(name!).ToUpper()}" }; + if (!string.IsNullOrEmpty(name)) + { + yield return new object[] { $"{GetPartialNameForPatternMatch(name!)}%" }; + yield return new object[] { $"{GetPartialNameForPatternMatch(name!).ToLower()}%" }; + yield return new object[] { $"{GetPartialNameForPatternMatch(name!).ToUpper()}%" }; + yield return new object[] { $"_{GetNameWithoutFirstChatacter(name!)}" }; + yield return new object[] { $"_{GetNameWithoutFirstChatacter(name!).ToLower()}" }; + yield return new object[] { $"_{GetNameWithoutFirstChatacter(name!).ToUpper()}" }; + } } private static string GetPartialNameForPatternMatch(string name) @@ -561,7 +661,7 @@ private static string GetNameWithoutFirstChatacter(string name) /// /// Represents a temporary table that can create and drop the table automatically. /// - protected class TemporaryTable : IDisposable + public class TemporaryTable : IDisposable { private bool _disposedValue; private readonly AdbcStatement _statement; @@ -584,7 +684,7 @@ private TemporaryTable(AdbcStatement statement, string tableName) /// The name of temporary table to create. /// The SQL query to create the table in the native SQL dialect. /// - public static async ValueTask NewTemporaryTableAsync(AdbcStatement statement, string tableName, string sqlUpdate) + public static async Task NewTemporaryTableAsync(AdbcStatement statement, string tableName, string sqlUpdate) { statement.SqlQuery = sqlUpdate; await statement.ExecuteUpdateAsync(); @@ -636,7 +736,8 @@ private TemporarySchema(string catalogName, AdbcStatement statement) public static async ValueTask NewTemporarySchemaAsync(string catalogName, AdbcStatement statement) { TemporarySchema schema = new TemporarySchema(catalogName, statement); - statement.SqlQuery = $"CREATE SCHEMA IF NOT EXISTS {schema.CatalogName}.{schema.SchemaName}"; + string catalog = string.IsNullOrEmpty(schema.CatalogName) ? string.Empty : schema.CatalogName + "."; + statement.SqlQuery = $"CREATE SCHEMA IF NOT EXISTS {catalog}{schema.SchemaName}"; await statement.ExecuteUpdateAsync(); return schema; } @@ -651,7 +752,8 @@ protected virtual void Dispose(bool disposing) { if (disposing) { - _statement.SqlQuery = $"DROP SCHEMA IF EXISTS {CatalogName}.{SchemaName}"; + string catalog = string.IsNullOrEmpty(CatalogName) ? string.Empty : CatalogName + "."; + _statement.SqlQuery = $"DROP SCHEMA IF EXISTS {catalog}{SchemaName}"; _statement.ExecuteUpdate(); } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/TestConfiguration.cs b/csharp/test/Apache.Arrow.Adbc.Tests/TestConfiguration.cs index 6365acd502..36a46a55a3 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/TestConfiguration.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/TestConfiguration.cs @@ -15,6 +15,7 @@ * limitations under the License. */ +using System; using System.Text.Json.Serialization; namespace Apache.Arrow.Adbc.Tests @@ -22,7 +23,7 @@ namespace Apache.Arrow.Adbc.Tests /// /// Base test configuration values. /// - public abstract class TestConfiguration + public abstract class TestConfiguration : ICloneable { /// /// The query to run. 
@@ -41,6 +42,8 @@ public abstract class TestConfiguration /// [JsonPropertyName("metadata")] public TestMetadata Metadata { get; set; } = new TestMetadata(); + + public virtual object Clone() => MemberwiseClone(); } /// diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/TestEnvironment.cs b/csharp/test/Apache.Arrow.Adbc.Tests/TestEnvironment.cs new file mode 100644 index 0000000000..374c102f37 --- /dev/null +++ b/csharp/test/Apache.Arrow.Adbc.Tests/TestEnvironment.cs @@ -0,0 +1,93 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Apache.Arrow.Adbc.Tests +{ + public abstract class TestEnvironment where TConfig : TestConfiguration + { + private readonly Func _getConnection; + + public abstract class Factory where TEnv : TestEnvironment + { + public abstract TEnv Create(Func getConnection); + } + + protected TestEnvironment(Func getConnection) + { + _getConnection = getConnection; + } + + public abstract string TestConfigVariable { get; } + + public abstract string VendorVersion { get; } + + public abstract string SqlDataResourceLocation { get; } + + public abstract int ExpectedColumnCount { get; } + + public virtual bool SupportsDelete => true; + + public virtual bool SupportsUpdate => true; + + public virtual bool SupportCatalogName => true; + + public virtual bool ValidateAffectedRows => true; + + public abstract AdbcDriver CreateNewDriver(); + + public abstract SampleDataBuilder GetSampleDataBuilder(); + + public abstract Dictionary GetDriverParameters(TConfig testConfiguration); + + public virtual string GetCreateTemporaryTableStatement(string tableName, string columns) + { + return string.Format("CREATE TEMPORARY IF NOT EXISTS TABLE {0} ({1})", tableName, columns); + } + + public virtual string Delimiter => "\""; + + public virtual string GetInsertStatement(string tableName, string columnName, string? value) => + string.Format("INSERT INTO {0} ({1}) VALUES ({2});", tableName, columnName, value ?? "NULL"); + + public virtual string GetDeleteValueStatement(string tableName, string whereClause) => + string.Format("DELETE FROM {0} {1};", tableName, whereClause); + + public string GetInsertStatementWithIndexColumn(string tableName, string columnName, string indexColumnName, object?[] values, string?[]? formattedValues) + { + var completeValues = new StringBuilder(); + if (values.Length == 0) throw new ArgumentOutOfRangeException(nameof(values), values.Length, "Must provide a non-zero length array of test values."); + for (int i = 0; i < values.Length; i++) + { + object? value = values[i]; + string? formattedValue = formattedValues?[i]; + string separator = (completeValues.Length != 0) ? ", " : ""; + completeValues.AppendLine($"{separator}({i}, {formattedValue ?? 
value?.ToString() ?? "NULL"})"); + } + + string insertStatement = $"INSERT INTO {tableName} ({indexColumnName}, {columnName}) VALUES {completeValues}"; + return insertStatement; + } + + public Version VendorVersionAsVersion => new Lazy(() => new Version(VendorVersion)).Value; + + public AdbcConnection Connection => _getConnection(); + } +} diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Utils.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Utils.cs index e07c14f7a9..07eb466b93 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/Utils.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Utils.cs @@ -102,5 +102,18 @@ public static T GetTestConfiguration(string? fileName) return testConfiguration; } + + /// + /// Formats a message to include a specific environment (if present). + /// + /// The message to format. + /// The name of the environment. + public static string FormatMessage(string message, string? environmentName) + { + if (!string.IsNullOrEmpty(environmentName)) + return $"{message} in the [{environmentName}] environment"; + + return message; + } } } diff --git a/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj b/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj index 63e6f2c87e..35223044db 100644 --- a/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj +++ b/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj @@ -6,17 +6,14 @@ - - - - - - - + + + all runtime; build; native; contentfiles; analyzers; buildtransitive - + + @@ -31,19 +28,24 @@ - + + PreserveNewest + + + PreserveNewest + + PreserveNewest PreserveNewest + + PreserveNewest + PreserveNewest - - - - diff --git a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs index 81e7355817..ea3d7d16ec 100644 --- a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs +++ b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs @@ -27,11 +27,40 @@ public class ApacheTestConfiguration : TestConfiguration [JsonPropertyName("port")] public string Port { get; set; } = string.Empty; - [JsonPropertyName("token"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] - public string Token { get; set; } = string.Empty; - [JsonPropertyName("path"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public string Path { get; set; } = string.Empty; + [JsonPropertyName("username"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string Username { get; set; } = string.Empty; + + [JsonPropertyName("password"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string Password { get; set; } = string.Empty; + + [JsonPropertyName("auth_type"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string AuthType { get; set; } = string.Empty; + + [JsonPropertyName("uri"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string Uri { get; set; } = string.Empty; + + [JsonPropertyName("batch_size"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string BatchSize { get; set; } = string.Empty; + + [JsonPropertyName("polltime_ms"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string PollTimeMilliseconds { get; set; } = string.Empty; + + [JsonPropertyName("connect_timeout_ms"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string ConnectTimeoutMilliseconds { get; set; } = string.Empty; + + [JsonPropertyName("query_timeout_s"), 
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string QueryTimeoutSeconds { get; set; } = string.Empty; + + [JsonPropertyName("type"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string Type { get; set; } = string.Empty; + + [JsonPropertyName("data_type_conv"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string DataTypeConversion { get; set; } = string.Empty; + + [JsonPropertyName("tls_options"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string TlsOptions { get; set; } = string.Empty; } } diff --git a/csharp/test/Drivers/Apache/Common/BinaryBooleanValueTests.cs b/csharp/test/Drivers/Apache/Common/BinaryBooleanValueTests.cs new file mode 100644 index 0000000000..fb6d355443 --- /dev/null +++ b/csharp/test/Drivers/Apache/Common/BinaryBooleanValueTests.cs @@ -0,0 +1,154 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common +{ + // TODO: When supported, use prepared statements instead of SQL string literals + // Which will better test how the driver handles values sent/received + + /// + /// Validates that specific binary and boolean values can be inserted, retrieved and targeted correctly + /// + public abstract class BinaryBooleanValueTests : TestBase + where TConfig : TestConfiguration + where TEnv : HiveServer2TestEnvironment + { + public BinaryBooleanValueTests(ITestOutputHelper output, TestEnvironment.Factory testEnvFactory) + : base(output, testEnvFactory) { } + + public static IEnumerable ByteArrayData(int size) + { + var rnd = new Random(); + byte[] bytes = new byte[size]; + rnd.NextBytes(bytes); + yield return new object[] { bytes }; + } + + /// + /// Validates if driver can send and receive specific Binary values correctly. + /// + [SkippableTheory] + [InlineData(null)] + [MemberData(nameof(ByteArrayData), 0)] + [MemberData(nameof(ByteArrayData), 2)] + [MemberData(nameof(ByteArrayData), 1024)] + public async Task TestBinaryData(byte[]? value) + { + string columnName = "BINARYTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "BINARY")); + string? formattedValue = value != null ? $"X'{BitConverter.ToString(value).Replace("-", "")}'" : null; + await ValidateInsertSelectDeleteSingleValueAsync( + table.TableName, + columnName, + value, + formattedValue); + } + + /// + /// Validates if driver can send and receive specific Boolean values correctly. 
+ /// + [SkippableTheory] + [InlineData(null)] + [InlineData(true)] + [InlineData(false)] + public async Task TestBooleanData(bool? value) + { + string columnName = "BOOLEANTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "BOOLEAN")); + string? formattedValue = value == null ? null : $"{value?.ToString(CultureInfo.InvariantCulture)}"; + await ValidateInsertSelectDeleteTwoValuesAsync( + table.TableName, + columnName, + value, + formattedValue); + } + + /// + /// Validates if driver can receive specific NULL values correctly. + /// + [SkippableTheory] + [InlineData("NULL")] + [InlineData("CAST(NULL AS INT)")] + [InlineData("CAST(NULL AS BIGINT)")] + [InlineData("CAST(NULL AS SMALLINT)")] + [InlineData("CAST(NULL AS TINYINT)")] + [InlineData("CAST(NULL AS FLOAT)")] + [InlineData("CAST(NULL AS DOUBLE)")] + [InlineData("CAST(NULL AS DECIMAL(38,0))")] + [InlineData("CAST(NULL AS STRING)")] + [InlineData("CAST(NULL AS VARCHAR(10))")] + [InlineData("CAST(NULL AS CHAR(10))")] + [InlineData("CAST(NULL AS BOOLEAN)")] + [InlineData("CAST(NULL AS BINARY)")] + [InlineData("CAST(NULL AS MAP)")] + [InlineData("CAST(NULL AS STRUCT)")] + [InlineData("CAST(NULL AS ARRAY)")] + public async Task TestNullData(string projectionClause) + { + string selectStatement = $"SELECT {projectionClause};"; + // Note: by default, this returns as String type, not NULL type. + await SelectAndValidateValuesAsync(selectStatement, (object?)null, 1); + } + + [SkippableTheory] + [InlineData(1)] + [InlineData(7)] + [InlineData(8)] + [InlineData(9)] + [InlineData(15)] + [InlineData(16)] + [InlineData(17)] + [InlineData(23)] + [InlineData(24)] + [InlineData(25)] + [InlineData(31)] + [InlineData(32)] // Full integer + [InlineData(33)] + [InlineData(39)] + [InlineData(40)] + [InlineData(41)] + [InlineData(47)] + [InlineData(48)] + [InlineData(49)] + [InlineData(63)] + [InlineData(64)] // Full 2 integers + [InlineData(65)] + public async Task TestMultilineNullData(int numberOfValues) + { + Random rnd = new(); + int percentIsNull = 50; + + object?[] values = new object?[numberOfValues]; + for (int i = 0; i < numberOfValues; i++) + { + values[i] = rnd.Next(0, 100) < percentIsNull ? null : rnd.Next(0, 2) != 0; + } + string columnName = "BOOLEANTYPE"; + string indexColumnName = "INDEXCOL"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}, {2} {3}", indexColumnName, "INT", columnName, "BOOLEAN")); + await ValidateInsertSelectDeleteMultipleValuesAsync(table.TableName, columnName, indexColumnName, values); + } + } +} diff --git a/csharp/test/Drivers/Apache/Common/ClientTests.cs b/csharp/test/Drivers/Apache/Common/ClientTests.cs new file mode 100644 index 0000000000..9148d72811 --- /dev/null +++ b/csharp/test/Drivers/Apache/Common/ClientTests.cs @@ -0,0 +1,236 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Client; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tests.Xunit; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common +{ + /// + /// Class for testing the ADBC Client using the Spark ADBC driver. + /// + /// + /// Tests are ordered to ensure data is created for the other + /// queries to run. + /// Note: This test creates/replaces the table identified in the configuration (metadata/table). + /// It uses the test collection "TableCreateTestCollection" to ensure it does not run + /// at the same time as any other tests that may create/update the same table. + /// + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] + [Collection("TableCreateTestCollection")] + public abstract class ClientTests : TestBase + where TConfig : TestConfiguration + where TEnv : HiveServer2TestEnvironment + { + public ClientTests(ITestOutputHelper? outputHelper, TestEnvironment.Factory testEnvFactory) + : base(outputHelper, testEnvFactory) + { + Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); + } + + /// + /// Validates if the client can execute updates. + /// + [SkippableFact, Order(1)] + public void CanClientExecuteUpdate() + { + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection()) + { + adbcConnection.Open(); + + string[] queries = GetQueries(); + + var expectedResults = GetUpdateExpectedResults(); + Tests.ClientTests.CanClientExecuteUpdate(adbcConnection, TestConfiguration, queries, expectedResults); + } + } + + protected abstract IReadOnlyList GetUpdateExpectedResults(); + + /// + /// Validates if the client can get the schema. + /// + [SkippableFact, Order(2)] + public void CanClientGetSchema() + { + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection()) + { + Tests.ClientTests.CanClientGetSchema(adbcConnection, TestConfiguration, $"SELECT * FROM {TestConfiguration.Metadata.Table}"); + } + } + + /// + /// Validates if the client can connect to a live server and + /// parse the results. + /// + [SkippableFact, Order(3)] + public void CanClientExecuteQuery() + { + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection()) + { + Tests.ClientTests.CanClientExecuteQuery(adbcConnection, TestConfiguration); + } + } + + /// + /// Validates if the client can connect to a live server and + /// parse the results. + /// + [SkippableFact, Order(5)] + public void CanClientExecuteEmptyQuery() + { + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection()) + { + Tests.ClientTests.CanClientExecuteQuery( + adbcConnection, + TestConfiguration, + customQuery: $"SELECT * FROM {TestConfiguration.Metadata.Table} WHERE FALSE", + expectedResultsCount: 0); + } + } + + /// + /// Validates if the client is retrieving and converting values + /// to the expected types.
+ /// + [SkippableFact, Order(4)] + public void VerifyTypesAndValues() + { + using (Adbc.Client.AdbcConnection dbConnection = GetAdbcConnection()) + { + SampleDataBuilder sampleDataBuilder = GetSampleDataBuilder(); + + Tests.ClientTests.VerifyTypesAndValues(dbConnection, sampleDataBuilder); + } + } + + [SkippableFact] + public void VerifySchemaTablesWithNoConstraints() + { + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection(includeTableConstraints: false)) + { + adbcConnection.Open(); + + string schema = "Tables"; + + var tables = adbcConnection.GetSchema(schema); + + Assert.True(tables.Rows.Count > 0, $"No tables were found in the schema '{schema}'"); + } + } + + [SkippableFact] + public void VerifySchemaTables() + { + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection()) + { + adbcConnection.Open(); + + var collections = adbcConnection.GetSchema("MetaDataCollections"); + Assert.Equal(7, collections.Rows.Count); + Assert.Equal(2, collections.Columns.Count); + + var restrictions = adbcConnection.GetSchema("Restrictions"); + Assert.Equal(11, restrictions.Rows.Count); + Assert.Equal(3, restrictions.Columns.Count); + + var catalogs = adbcConnection.GetSchema("Catalogs"); + Assert.Single(catalogs.Columns); + var catalog = (string?)catalogs.Rows[0].ItemArray[0]; + + catalogs = adbcConnection.GetSchema("Catalogs", new[] { catalog }); + Assert.Equal(1, catalogs.Rows.Count); + + string random = "X" + Guid.NewGuid().ToString("N"); + + catalogs = adbcConnection.GetSchema("Catalogs", new[] { random }); + Assert.Equal(0, catalogs.Rows.Count); + + var schemas = adbcConnection.GetSchema("Schemas", new[] { catalog }); + Assert.Equal(2, schemas.Columns.Count); + var schema = (string?)schemas.Rows[0].ItemArray[1]; + + schemas = adbcConnection.GetSchema("Schemas", new[] { catalog, schema }); + Assert.Equal(1, schemas.Rows.Count); + + schemas = adbcConnection.GetSchema("Schemas", new[] { random }); + Assert.Equal(0, schemas.Rows.Count); + + schemas = adbcConnection.GetSchema("Schemas", new[] { catalog, random }); + Assert.Equal(0, schemas.Rows.Count); + + schemas = adbcConnection.GetSchema("Schemas", new[] { random, random }); + Assert.Equal(0, schemas.Rows.Count); + + var tableTypes = adbcConnection.GetSchema("TableTypes"); + Assert.Single(tableTypes.Columns); + + var tables = adbcConnection.GetSchema("Tables", new[] { catalog, schema }); + Assert.Equal(4, tables.Columns.Count); + + tables = adbcConnection.GetSchema("Tables", new[] { catalog, random }); + Assert.Equal(0, tables.Rows.Count); + + tables = adbcConnection.GetSchema("Tables", new[] { random, schema }); + Assert.Equal(0, tables.Rows.Count); + + tables = adbcConnection.GetSchema("Tables", new[] { random, random }); + Assert.Equal(0, tables.Rows.Count); + + tables = adbcConnection.GetSchema("Tables", new[] { catalog, schema, random }); + Assert.Equal(0, tables.Rows.Count); + + var columns = adbcConnection.GetSchema("Columns", new[] { catalog, schema }); + Assert.Equal(16, columns.Columns.Count); + } + } + + [SkippableFact] + public void VerifyTimeoutsSet() + { + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection()) + { + int timeout = 99; + using AdbcCommand cmd = adbcConnection.CreateCommand(); + + // setting the timeout before the property value + Assert.Throws(() => + { + cmd.CommandTimeout = 1; + }); + + cmd.AdbcCommandTimeoutProperty = "adbc.apache.statement.query_timeout_s"; + cmd.CommandTimeout = timeout; + + Assert.True(cmd.CommandTimeout == timeout, $"CommandTimeout is not set to
{timeout}"); + } + } + + private Adbc.Client.AdbcConnection GetAdbcConnection(bool includeTableConstraints = true) + { + return new Adbc.Client.AdbcConnection( + NewDriver, GetDriverParameters(TestConfiguration), + [] + ); + } + } +} diff --git a/csharp/test/Drivers/Apache/Common/ComplexTypesValueTests.cs b/csharp/test/Drivers/Apache/Common/ComplexTypesValueTests.cs new file mode 100644 index 0000000000..9a785c0604 --- /dev/null +++ b/csharp/test/Drivers/Apache/Common/ComplexTypesValueTests.cs @@ -0,0 +1,81 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common +{ + // TODO: When supported, use prepared statements instead of SQL string literals + // Which will better test how the driver handles values sent/received + + /// + /// Validates that specific complex structured types can be inserted, retrieved and targeted correctly + /// + public class ComplexTypesValueTests : TestBase + where TConfig : TestConfiguration + where TEnv : HiveServer2TestEnvironment + { + public ComplexTypesValueTests(ITestOutputHelper output, TestEnvironment.Factory testEnvFactory) + : base(output, testEnvFactory) { } + + /// + /// Validates if driver can send and receive specific array of integer values correctly. + /// + [SkippableTheory] + [InlineData("ARRAY(CAST(1 AS INT), 2, 3)", "[1,2,3]")] + [InlineData("ARRAY(CAST(1 AS LONG), 2, 3)", "[1,2,3]")] + [InlineData("ARRAY(CAST(1 AS DOUBLE), 2, 3)", "[1.0,2.0,3.0]")] + [InlineData("ARRAY(CAST(1 AS NUMERIC(38,0)), 2, 3)", "[1,2,3]")] + [InlineData("ARRAY(CAST('John Doe' AS STRING), 2, 3)", """["John Doe","2","3"]""")] + // Note: Timestamp returned adjusted to UTC. + [InlineData("ARRAY(CAST('2024-01-01T00:00:00-07:00' AS TIMESTAMP), CAST('2024-02-02T02:02:02+01:30' AS TIMESTAMP), CAST('2024-03-03T03:03:03Z' AS TIMESTAMP))", """[2024-01-01 07:00:00,2024-02-02 00:32:02,2024-03-03 03:03:03]""")] + [InlineData("ARRAY(CAST('2024-01-01T00:00:00Z' AS DATE), CAST('2024-02-02T02:02:02Z' AS DATE), CAST('2024-03-03T03:03:03Z' AS DATE))", """[2024-01-01,2024-02-02,2024-03-03]""")] + [InlineData("ARRAY(INTERVAL 123 YEARS 11 MONTHS, INTERVAL 5 YEARS, INTERVAL 6 MONTHS)", """[123-11,5-0,0-6]""")] + public async Task TestArrayData(string projection, string value) + { + string selectStatement = $"SELECT {projection};"; + await SelectAndValidateValuesAsync(selectStatement, value, 1); + } + + /// + /// Validates if driver can send and receive specific map values correctly. 
+ /// + [SkippableTheory] + [InlineData("MAP(1, 'John Doe', 2, 'Jane Doe', 3, 'Jack Doe')", """{1:"John Doe",2:"Jane Doe",3:"Jack Doe"}""")] + [InlineData("MAP('John Doe', 1, 'Jane Doe', 2, 'Jack Doe', 3)", """{"Jack Doe":3,"Jane Doe":2,"John Doe":1}""")] + public async Task TestMapData(string projection, string value) + { + string selectStatement = $"SELECT {projection};"; + await SelectAndValidateValuesAsync(selectStatement, value, 1); + } + + /// + /// Validates if driver can send and receive specific struct values correctly. + /// + [SkippableTheory] + [InlineData("STRUCT(CAST(1 AS INT), CAST('John Doe' AS STRING))", """{"col1":1,"col2":"John Doe"}""")] + [InlineData("STRUCT(CAST('John Doe' AS STRING), CAST(1 AS INT))", """{"col1":"John Doe","col2":1}""")] + public async Task TestStructData(string projection, string value) + { + string selectStatement = $"SELECT {projection};"; + await SelectAndValidateValuesAsync(selectStatement, value, 1); + } + } +} diff --git a/csharp/test/Drivers/Apache/Common/DateTimeValueTests.cs b/csharp/test/Drivers/Apache/Common/DateTimeValueTests.cs new file mode 100644 index 0000000000..089de26a6c --- /dev/null +++ b/csharp/test/Drivers/Apache/Common/DateTimeValueTests.cs @@ -0,0 +1,145 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common +{ + // TODO: When supported, use prepared statements instead of SQL string literals + // Which will better test how the driver handles values sent/received + + /// + /// Validates that specific date, timestamp and interval values can be inserted, retrieved and targeted correctly + /// + public abstract class DateTimeValueTests : TestBase + where TConfig : TestConfiguration + where TEnv : HiveServer2TestEnvironment + { + // Spark handles microseconds but not nanoseconds. Truncated to 6 decimal places.
+ const string DateTimeZoneFormat = "yyyy-MM-dd'T'HH:mm:ss'.'ffffffK"; + const string DateTimeFormat = "yyyy-MM-dd' 'HH:mm:ss"; + protected const string DateFormat = "yyyy-MM-dd"; + + private static readonly DateTimeOffset[] s_timestampValues = + [ +#if NET5_0_OR_GREATER + DateTimeOffset.UnixEpoch, +#endif + DateTimeOffset.MinValue, + DateTimeOffset.MaxValue, + DateTimeOffset.UtcNow, + DateTimeOffset.UtcNow.ToOffset(TimeSpan.FromHours(4)) + ]; + + public DateTimeValueTests(ITestOutputHelper output, TestEnvironment.Factory testEnvFactory) + : base(output, testEnvFactory) { } + + /// + /// Validates if driver can send and receive specific Timestamp values correctly + /// + [SkippableTheory] + [MemberData(nameof(TimestampData), "TIMESTAMP")] + public async Task TestTimestampData(DateTimeOffset value, string columnType) + { + string columnName = "TIMESTAMPTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, columnType)); + + string format = TestEnvironment.GetValueForProtocolVersion(DateTimeFormat, DateTimeZoneFormat)!; + string formattedValue = $"{value.ToString(format, CultureInfo.InvariantCulture)}"; + DateTimeOffset truncatedValue = DateTimeOffset.ParseExact(formattedValue, format, CultureInfo.InvariantCulture); + + object expectedValue = TestEnvironment.GetValueForProtocolVersion(formattedValue, truncatedValue)!; + await ValidateInsertSelectDeleteSingleValueAsync( + table.TableName, + columnName, + expectedValue, + "TO_TIMESTAMP(" + QuoteValue(formattedValue) + ")"); + } + + /// + /// Validates if driver can send and receive specific no timezone Timestamp values correctly + /// + [SkippableTheory] + [MemberData(nameof(TimestampData), "DATE")] + public async Task TestDateData(DateTimeOffset value, string columnType) + { + string columnName = "DATETYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, columnType)); + + string formattedValue = $"{value.ToString(DateFormat, CultureInfo.InvariantCulture)}"; + DateTimeOffset truncatedValue = DateTimeOffset.ParseExact(formattedValue, DateFormat, CultureInfo.InvariantCulture); + + // Remove timezone offset + object expectedValue = TestEnvironment.GetValueForProtocolVersion(formattedValue, new DateTimeOffset(truncatedValue.DateTime, TimeSpan.Zero))!; + await ValidateInsertSelectDeleteSingleValueAsync( + table.TableName, + columnName, + expectedValue, + "TO_DATE(" + QuoteValue(formattedValue) + ")"); + } + + /// + /// Tests INTERVAL data types (YEAR-MONTH and DAY-SECOND). + /// + /// The INTERVAL to test. + /// The expected return value.
+ /// + [SkippableTheory] + [InlineData("INTERVAL 1 YEAR", "1-0")] + [InlineData("INTERVAL 1 YEAR 2 MONTH", "1-2")] + [InlineData("INTERVAL 2 MONTHS", "0-2")] + [InlineData("INTERVAL -1 YEAR", "-1-0")] + [InlineData("INTERVAL -1 YEAR 2 MONTH", "-0-10")] + [InlineData("INTERVAL -2 YEAR 2 MONTH", "-1-10")] + [InlineData("INTERVAL 1 YEAR -2 MONTH", "0-10")] + [InlineData("INTERVAL 178956970 YEAR", "178956970-0")] + [InlineData("INTERVAL 178956969 YEAR 11 MONTH", "178956969-11")] + [InlineData("INTERVAL -178956970 YEAR", "-178956970-0")] + [InlineData("INTERVAL 0 DAYS 0 HOURS 0 MINUTES 0 SECONDS", "0 00:00:00.000000000")] + [InlineData("INTERVAL 1 DAYS", "1 00:00:00.000000000")] + [InlineData("INTERVAL 2 HOURS", "0 02:00:00.000000000")] + [InlineData("INTERVAL 3 MINUTES", "0 00:03:00.000000000")] + [InlineData("INTERVAL 4 SECONDS", "0 00:00:04.000000000")] + [InlineData("INTERVAL 1 DAYS 2 HOURS", "1 02:00:00.000000000")] + [InlineData("INTERVAL 1 DAYS 2 HOURS 3 MINUTES", "1 02:03:00.000000000")] + [InlineData("INTERVAL 1 DAYS 2 HOURS 3 MINUTES 4 SECONDS", "1 02:03:04.000000000")] + [InlineData("INTERVAL 1 DAYS 2 HOURS 3 MINUTES 4.123123123 SECONDS", "1 02:03:04.123123000")] // Only to microseconds + [InlineData("INTERVAL 106751990 DAYS 23 HOURS 59 MINUTES 59.999999 SECONDS", "106751990 23:59:59.999999000")] + [InlineData("INTERVAL 106751991 DAYS 0 HOURS 0 MINUTES 0 SECONDS", "106751991 00:00:00.000000000")] + [InlineData("INTERVAL -106751991 DAYS 0 HOURS 0 MINUTES 0 SECONDS", "-106751991 00:00:00.000000000")] + [InlineData("INTERVAL -106751991 DAYS 23 HOURS 59 MINUTES 59.999999 SECONDS", "-106751990 00:00:00.000001000")] + public async Task TestIntervalData(string intervalClause, string value) + { + string selectStatement = $"SELECT {intervalClause} AS INTERVAL_VALUE;"; + await SelectAndValidateValuesAsync(selectStatement, value, 1); + } + + public static IEnumerable TimestampData(string columnType) + { + foreach (DateTimeOffset timestamp in s_timestampValues) + { + yield return new object[] { timestamp, columnType }; + } + } + } +} diff --git a/csharp/test/Drivers/Apache/Common/DriverTests.cs b/csharp/test/Drivers/Apache/Common/DriverTests.cs new file mode 100644 index 0000000000..33c466cfc4 --- /dev/null +++ b/csharp/test/Drivers/Apache/Common/DriverTests.cs @@ -0,0 +1,672 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tests.Metadata; +using Apache.Arrow.Adbc.Tests.Xunit; +using Apache.Arrow.Ipc; +using Xunit; +using Xunit.Abstractions; +using ColumnTypeId = Apache.Arrow.Adbc.Drivers.Apache.Spark.SparkConnection.ColumnTypeId; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common +{ + /// + /// Class for testing the Spark ADBC driver connection tests. + /// + /// + /// Tests are ordered to ensure data is created for the other + /// queries to run. + /// Note: This test creates/replaces the table identified in the configuration (metadata/table). + /// It uses the test collection "TableCreateTestCollection" to ensure it does not run + /// at the same time as any other tests that may create/update the same table. + /// + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] + [Collection("TableCreateTestCollection")] + public abstract class DriverTests : TestBase + where TConfig : ApacheTestConfiguration + where TEnv : HiveServer2TestEnvironment + { + /// + /// Supported Spark data types as a subset of + /// + private enum SupportedSparkDataType : short + { + ARRAY = ColumnTypeId.ARRAY, + BIGINT = ColumnTypeId.BIGINT, + BINARY = ColumnTypeId.BINARY, + BOOLEAN = ColumnTypeId.BOOLEAN, + CHAR = ColumnTypeId.CHAR, + DATE = ColumnTypeId.DATE, + DECIMAL = ColumnTypeId.DECIMAL, + DOUBLE = ColumnTypeId.DOUBLE, + FLOAT = ColumnTypeId.FLOAT, + INTEGER = ColumnTypeId.INTEGER, + JAVA_OBJECT = ColumnTypeId.JAVA_OBJECT, + LONGNVARCHAR = ColumnTypeId.LONGNVARCHAR, + LONGVARBINARY = ColumnTypeId.LONGVARBINARY, + LONGVARCHAR = ColumnTypeId.LONGVARCHAR, + NCHAR = ColumnTypeId.NCHAR, + NULL = ColumnTypeId.NULL, + NUMERIC = ColumnTypeId.NUMERIC, + NVARCHAR = ColumnTypeId.NVARCHAR, + REAL = ColumnTypeId.REAL, + SMALLINT = ColumnTypeId.SMALLINT, + STRUCT = ColumnTypeId.STRUCT, + TIMESTAMP = ColumnTypeId.TIMESTAMP, + TINYINT = ColumnTypeId.TINYINT, + VARBINARY = ColumnTypeId.VARBINARY, + VARCHAR = ColumnTypeId.VARCHAR, + } + + private static List DefaultTableTypes => new() { "TABLE", "VIEW" }; + + public DriverTests(ITestOutputHelper? outputHelper, TestEnvironment.Factory testEnvFactory) + : base(outputHelper, testEnvFactory) + { + Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); + } + + /// + /// Validates if the driver can execute update statements. + /// + [SkippableFact, Order(1)] + public void CanExecuteUpdate() + { + AdbcConnection adbcConnection = NewConnection(); + + string[] queries = GetQueries(); + + //List expectedResults = TestEnvironment.ServerType != SparkServerType.Databricks + // ?
+ // [ + // -1, // CREATE TABLE + // 1, // INSERT + // 1, // INSERT + // 1, // INSERT + // //1, // UPDATE + // //1, // DELETE + // ] + // : + // [ + // -1, // CREATE TABLE + // 1, // INSERT + // 1, // INSERT + // 1, // INSERT + // 1, // UPDATE + // 1, // DELETE + // ]; + + var expectedResults = GetUpdateExpectedResults(); + for (int i = 0; i < queries.Length; i++) + { + string query = queries[i]; + using AdbcStatement statement = adbcConnection.CreateStatement(); + statement.SqlQuery = query; + + UpdateResult updateResult = statement.ExecuteUpdate(); + + if (ValidateAffectedRows) Assert.Equal(expectedResults[i], updateResult.AffectedRows); + } + } + + protected abstract IReadOnlyList GetUpdateExpectedResults(); + + /// + /// Validates if the driver can call GetInfo. + /// + [SkippableFact, Order(2)] + public async Task CanGetInfo() + { + AdbcConnection adbcConnection = NewConnection(); + + // Test the supported info codes + List handledCodes = new List() + { + AdbcInfoCode.DriverName, + AdbcInfoCode.DriverVersion, + AdbcInfoCode.VendorName, + AdbcInfoCode.DriverArrowVersion, + AdbcInfoCode.VendorVersion, + AdbcInfoCode.VendorSql + }; + using IArrowArrayStream stream = adbcConnection.GetInfo(handledCodes); + + RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync(); + UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name"); + + List expectedValues = new List() + { + "DriverName", + "DriverVersion", + "VendorName", + "DriverArrowVersion", + "VendorVersion", + "VendorSql" + }; + + for (int i = 0; i < infoNameArray.Length; i++) + { + AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i); + DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); + + Assert.Contains(value.ToString(), expectedValues); + + switch (value) + { + case AdbcInfoCode.VendorSql: + // TODO: How does external developer know the second field is the boolean field? + BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1]; + bool? boolValue = booleanArray.GetValue(i); + OutputHelper?.WriteLine($"{value}={boolValue}"); + Assert.True(boolValue); + break; + default: + StringArray stringArray = (StringArray)valueArray.Fields[0]; + string stringValue = stringArray.GetString(i); + OutputHelper?.WriteLine($"{value}={stringValue}"); + Assert.NotNull(stringValue); + break; + } + } + + // Test the unhandled info codes. + List unhandledCodes = new List() + { + AdbcInfoCode.VendorArrowVersion, + AdbcInfoCode.VendorSubstrait, + AdbcInfoCode.VendorSubstraitMaxVersion + }; + using IArrowArrayStream stream2 = adbcConnection.GetInfo(unhandledCodes); + + recordBatch = await stream2.ReadNextRecordBatchAsync(); + infoNameArray = (UInt32Array)recordBatch.Column("info_name"); + + List unexpectedValues = new List() + { + "VendorArrowVersion", + "VendorSubstrait", + "VendorSubstraitMaxVersion" + }; + for (int i = 0; i < infoNameArray.Length; i++) + { + AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i); + DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); + + Assert.Contains(value.ToString(), unexpectedValues); + switch (value) + { + case AdbcInfoCode.VendorSql: + BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1]; + Assert.Null(booleanArray.GetValue(i)); + break; + default: + StringArray stringArray = (StringArray)valueArray.Fields[0]; + Assert.Null(stringArray.GetString(i)); + break; + } + } + } + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as Catalogs with CatalogPattern as a pattern. 
+ /// + /// + [SkippableFact, Order(3)] + public abstract void CanGetObjectsCatalogs(string? pattern); + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as Catalogs with CatalogPattern as a pattern. + /// + /// + protected void GetObjectsCatalogsTest(string? pattern) + { + string? catalogName = TestConfiguration.Metadata.Catalog; + string? schemaName = TestConfiguration.Metadata.Schema; + + using IArrowArrayStream stream = Connection.GetObjects( + depth: AdbcConnection.GetObjectsDepth.Catalogs, + catalogPattern: pattern, + dbSchemaPattern: null, + tableNamePattern: null, + tableTypes: DefaultTableTypes, + columnNamePattern: null); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); + AdbcCatalog? catalog = catalogs.Where((catalog) => string.Equals(catalog.Name, catalogName)).FirstOrDefault(); + + Assert.True(pattern == string.Empty && catalog == null || catalog != null, "catalog should not be null"); + } + + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as DbSchemas with DbSchemaName as a pattern. + /// + [SkippableFact, Order(4)] + public abstract void CanGetObjectsDbSchemas(string dbSchemaPattern); + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as DbSchemas with DbSchemaName as a pattern. + /// + protected void GetObjectsDbSchemasTest(string dbSchemaPattern) + { + // need to add the database + string? databaseName = TestConfiguration.Metadata.Catalog; + string? schemaName = TestConfiguration.Metadata.Schema; + + using IArrowArrayStream stream = Connection.GetObjects( + depth: AdbcConnection.GetObjectsDepth.DbSchemas, + catalogPattern: databaseName, + dbSchemaPattern: dbSchemaPattern, + tableNamePattern: null, + tableTypes: DefaultTableTypes, + columnNamePattern: null); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); + + List? dbSchemas = catalogs + .Where(c => string.Equals(c.Name, databaseName)) + .Select(c => c.DbSchemas) + .FirstOrDefault(); + AdbcDbSchema? dbSchema = dbSchemas?.Where((dbSchema) => string.Equals(dbSchema.Name, schemaName)).FirstOrDefault(); + + Assert.True(dbSchema != null, "dbSchema should not be null"); + } + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a pattern. + /// + [SkippableFact, Order(5)] + public abstract void CanGetObjectsTables(string tableNamePattern); + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a pattern. + /// + protected void GetObjectsTablesTest(string tableNamePattern) + { + // need to add the database + string? databaseName = TestConfiguration.Metadata.Catalog; + string? schemaName = TestConfiguration.Metadata.Schema; + string? tableName = TestConfiguration.Metadata.Table; + + using IArrowArrayStream stream = Connection.GetObjects( + depth: AdbcConnection.GetObjectsDepth.Tables, + catalogPattern: databaseName, + dbSchemaPattern: schemaName, + tableNamePattern: tableNamePattern, + tableTypes: DefaultTableTypes, + columnNamePattern: null); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); + + List? 
tables = catalogs + .Where(c => string.Equals(c.Name, databaseName)) + .Select(c => c.DbSchemas) + .FirstOrDefault() + ?.Where(s => string.Equals(s.Name, schemaName)) + .Select(s => s.Tables) + .FirstOrDefault(); + + AdbcTable? table = tables?.Where((table) => string.Equals(table.Name, tableName)).FirstOrDefault(); + Assert.True(table != null, "table should not be null"); + // TODO: Determine why this is returned blank. + //Assert.Equal("TABLE", table.Type); + } + + /// + /// Validates if the driver can call GetObjects for GetObjectsDepth as All. + /// + [SkippableFact, Order(6)] + public void CanGetObjectsAll() + { + // need to add the database + string? databaseName = TestConfiguration.Metadata.Catalog; + string? schemaName = TestConfiguration.Metadata.Schema; + string? tableName = TestConfiguration.Metadata.Table; + string? columnName = null; + + using IArrowArrayStream stream = Connection.GetObjects( + depth: AdbcConnection.GetObjectsDepth.All, + catalogPattern: databaseName, + dbSchemaPattern: schemaName, + tableNamePattern: tableName, + tableTypes: DefaultTableTypes, + columnNamePattern: columnName); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); + AdbcTable? table = catalogs + .Where(c => string.Equals(c.Name, databaseName)) + .Select(c => c.DbSchemas) + .FirstOrDefault() + ?.Where(s => string.Equals(s.Name, schemaName)) + .Select(s => s.Tables) + .FirstOrDefault() + ?.Where(t => string.Equals(t.Name, tableName)) + .FirstOrDefault(); + + Assert.True(table != null, "table should not be null"); + // TODO: Determine why this is returned blank. + //Assert.Equal("TABLE", table.Type); + List? columns = table.Columns; + + Assert.True(columns != null, "Columns cannot be null"); + Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, columns.Count); + + for (int i = 0; i < columns.Count; i++) + { + // Verify column metadata is returned/consistent. 
+ AdbcColumn column = columns[i]; + Assert.Equal(i + 1, column.OrdinalPosition); + Assert.False(string.IsNullOrEmpty(column.Name)); + Assert.False(string.IsNullOrEmpty(column.XdbcTypeName)); + Assert.False(Regex.IsMatch(column.XdbcTypeName, @"[_,\d\<\>\(\)]", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant), + "Unexpected character found in field XdbcTypeName"); + + var supportedTypes = Enum.GetValues(typeof(SupportedSparkDataType)).Cast(); + Assert.Contains((SupportedSparkDataType)column.XdbcSqlDataType!, supportedTypes); + Assert.Equal(column.XdbcDataType, column.XdbcSqlDataType); + + Assert.NotNull(column.XdbcDataType); + Assert.Contains((SupportedSparkDataType)column.XdbcDataType!, supportedTypes); + + HashSet typesHaveColumnSize = new() + { + (short)SupportedSparkDataType.DECIMAL, + (short)SupportedSparkDataType.NUMERIC, + (short)SupportedSparkDataType.CHAR, + (short)SupportedSparkDataType.VARCHAR, + }; + HashSet typesHaveDecimalDigits = new() + { + (short)SupportedSparkDataType.DECIMAL, + (short)SupportedSparkDataType.NUMERIC, + }; + + bool typeHasColumnSize = typesHaveColumnSize.Contains(column.XdbcDataType.Value); + Assert.Equal(column.XdbcColumnSize.HasValue, typeHasColumnSize); + + bool typeHasDecimalDigits = typesHaveDecimalDigits.Contains(column.XdbcDataType.Value); + Assert.Equal(column.XdbcDecimalDigits.HasValue, typeHasDecimalDigits); + + Assert.False(string.IsNullOrEmpty(column.Remarks)); + + Assert.NotNull(column.XdbcColumnDef); + + Assert.NotNull(column.XdbcNullable); + Assert.Contains(new short[] { 1, 0 }, i => i == column.XdbcNullable); + + Assert.NotNull(column.XdbcIsNullable); + Assert.Contains(new string[] { "YES", "NO" }, i => i.Equals(column.XdbcIsNullable)); + + Assert.NotNull(column.XdbcIsAutoIncrement); + + Assert.Null(column.XdbcCharOctetLength); + Assert.Null(column.XdbcDatetimeSub); + Assert.Null(column.XdbcNumPrecRadix); + Assert.Null(column.XdbcScopeCatalog); + Assert.Null(column.XdbcScopeSchema); + Assert.Null(column.XdbcScopeTable); + } + } + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a Special Character. + /// + [SkippableTheory, Order(7)] + [InlineData("MyIdentifier")] + [InlineData("ONE")] + [InlineData("mYiDentifier")] + [InlineData("3rd_identifier")] + // Note: Tables in 'hive_metastore' only support ASCII alphabetic, numeric and underscore. + public void CanGetObjectsTablesWithSpecialCharacter(string tableName) + { + string catalogName = TestConfiguration.Metadata.Catalog; + string schemaPrefix = Guid.NewGuid().ToString().Replace("-", ""); + using TemporarySchema schema = TemporarySchema.NewTemporarySchemaAsync(catalogName, Statement).Result; + string schemaName = schema.SchemaName; + string catalogFormatted = string.IsNullOrEmpty(catalogName) ? 
string.Empty : DelimitIdentifier(catalogName) + "."; + string fullTableName = $"{catalogFormatted}{DelimitIdentifier(schemaName)}.{DelimitIdentifier(tableName)}"; + using TemporaryTable temporaryTable = TemporaryTable.NewTemporaryTableAsync(Statement, fullTableName, $"CREATE TABLE IF NOT EXISTS {fullTableName} (INDEX INT)").Result; + + using IArrowArrayStream stream = Connection.GetObjects( + depth: AdbcConnection.GetObjectsDepth.Tables, + catalogPattern: catalogName, + dbSchemaPattern: schemaName, + tableNamePattern: tableName, + tableTypes: DefaultTableTypes, + columnNamePattern: null); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); + + List? tables = catalogs + .Where(c => string.Equals(c.Name, catalogName)) + .Select(c => c.DbSchemas) + .FirstOrDefault() + ?.Where(s => string.Equals(s.Name, schemaName)) + .Select(s => s.Tables) + .FirstOrDefault(); + + AdbcTable? table = tables?.FirstOrDefault(); + + Assert.True(table != null, "table should not be null"); + Assert.Equal(tableName, table.Name, true); + } + + /// + /// Validates if the driver can call GetTableSchema. + /// + [SkippableFact, Order(8)] + public void CanGetTableSchema() + { + AdbcConnection adbcConnection = NewConnection(); + + string? catalogName = TestConfiguration.Metadata.Catalog; + string? schemaName = TestConfiguration.Metadata.Schema; + string tableName = TestConfiguration.Metadata.Table!; + + Schema schema = adbcConnection.GetTableSchema(catalogName, schemaName, tableName); + + int numberOfFields = schema.FieldsList.Count; + + Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, numberOfFields); + } + + /// + /// Validates if the driver can call GetTableTypes. + /// + [SkippableFact, Order(9)] + public async Task CanGetTableTypes() + { + AdbcConnection adbcConnection = NewConnection(); + + using IArrowArrayStream arrowArrayStream = adbcConnection.GetTableTypes(); + + RecordBatch recordBatch = await arrowArrayStream.ReadNextRecordBatchAsync(); + + StringArray stringArray = (StringArray)recordBatch.Column("table_type"); + + List known_types = new List + { + "TABLE", "VIEW" + }; + + int results = 0; + + for (int i = 0; i < stringArray.Length; i++) + { + string value = stringArray.GetString(i); + + if (known_types.Contains(value)) + { + results++; + } + } + + Assert.Equal(known_types.Count, results); + } + + /// + /// Validates if the driver can connect to a live server and + /// parse the results. + /// + [SkippableTheory, Order(10)] + [InlineData(0.1)] + [InlineData(0.25)] + [InlineData(1.0)] + [InlineData(2.0)] + [InlineData(null)] + public void CanExecuteQuery(double? batchSizeFactor) + { + // Ensure all records can be retrieved, independent of the batch size. + TConfig testConfiguration = (TConfig)TestConfiguration.Clone(); + long expectedResultCount = testConfiguration.ExpectedResultsCount; + long nonZeroExpectedResultCount = expectedResultCount == 0 ? 1 : expectedResultCount; + testConfiguration.BatchSize = batchSizeFactor != null ? ((long)(nonZeroExpectedResultCount * batchSizeFactor)).ToString() : string.Empty; + OutputHelper?.WriteLine($"BatchSize: {testConfiguration.BatchSize}. 
ExpectedResultCount: {expectedResultCount}"); + + using AdbcConnection adbcConnection = NewConnection(testConfiguration); + + using AdbcStatement statement = adbcConnection.CreateStatement(); + statement.SqlQuery = TestConfiguration.Query; + OutputHelper?.WriteLine(statement.SqlQuery); + + QueryResult queryResult = statement.ExecuteQuery(); + + Tests.DriverTests.CanExecuteQuery(queryResult, TestConfiguration.ExpectedResultsCount); + } + + /// + /// Validates if the driver can connect to a live server and + /// parse the results using the asynchronous methods. + /// + [SkippableFact, Order(11)] + public async Task CanExecuteQueryAsync() + { + using AdbcConnection adbcConnection = NewConnection(); + using AdbcStatement statement = adbcConnection.CreateStatement(); + + statement.SqlQuery = TestConfiguration.Query; + QueryResult queryResult = await statement.ExecuteQueryAsync(); + + await Tests.DriverTests.CanExecuteQueryAsync(queryResult, TestConfiguration.ExpectedResultsCount); + } + + /// + /// Validates if the driver can connect to a live server and + /// perform an update asynchronously. + /// + [SkippableFact, Order(12)] + public async Task CanExecuteUpdateAsync() + { + using AdbcConnection adbcConnection = NewConnection(); + using AdbcStatement statement = adbcConnection.CreateStatement(); + using TemporaryTable temporaryTable = await NewTemporaryTableAsync(statement, "INDEX INT"); + + statement.SqlQuery = GetInsertStatement(temporaryTable.TableName, "INDEX", "1"); + UpdateResult updateResult = await statement.ExecuteUpdateAsync(); + + if (ValidateAffectedRows) Assert.Equal(1, updateResult.AffectedRows); + } + + [SkippableFact, Order(13)] + public void CanDetectInvalidAuthentication() + { + AdbcDriver driver = NewDriver; + Assert.NotNull(driver); + Dictionary parameters = GetDriverParameters(TestConfiguration); + + bool hasToken = parameters.TryGetValue(SparkParameters.Token, out var token) && !string.IsNullOrEmpty(token); + bool hasUsername = parameters.TryGetValue(AdbcOptions.Username, out var username) && !string.IsNullOrEmpty(username); + bool hasPassword = parameters.TryGetValue(AdbcOptions.Password, out var password) && !string.IsNullOrEmpty(password); + if (hasToken) + { + parameters[SparkParameters.Token] = "invalid-token"; + } + else if (hasUsername && hasPassword) + { + parameters[AdbcOptions.Password] = "invalid-password"; + } + else + { + Assert.Fail($"Unexpected configuration. Must provide '{SparkParameters.Token}' or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'."); + } + + AdbcDatabase database = driver.Open(parameters); + AggregateException exception = Assert.ThrowsAny(() => database.Connect(parameters)); + OutputHelper?.WriteLine(exception.Message); + } + + [SkippableFact, Order(14)] + public void CanDetectInvalidServer() + { + AdbcDriver driver = NewDriver; + Assert.NotNull(driver); + Dictionary parameters = GetDriverParameters(TestConfiguration); + + bool hasUri = parameters.TryGetValue(AdbcOptions.Uri, out var uri) && !string.IsNullOrEmpty(uri); + bool hasHostName = parameters.TryGetValue(SparkParameters.HostName, out var hostName) && !string.IsNullOrEmpty(hostName); + if (hasUri) + { + parameters[AdbcOptions.Uri] = "http://unknownhost.azure.com/cliservice"; + } + else if (hasHostName) + { + parameters[SparkParameters.HostName] = "unknownhost.azure.com"; + } + else + { + Assert.Fail($"Unexpected configuration.
Must provide '{AdbcOptions.Uri}' or '{SparkParameters.HostName}'."); + } + + AdbcDatabase database = driver.Open(parameters); + AggregateException exception = Assert.ThrowsAny(() => database.Connect(parameters)); + OutputHelper?.WriteLine(exception.Message); + } + + /// + /// Validates if the driver can connect to a live server and + /// parse the results using the asynchronous methods. + /// + [SkippableFact, Order(15)] + public async Task CanExecuteQueryAsyncEmptyResult() + { + using AdbcConnection adbcConnection = NewConnection(); + using AdbcStatement statement = adbcConnection.CreateStatement(); + + statement.SqlQuery = $"SELECT * from {TestConfiguration.Metadata.Table} WHERE FALSE"; + QueryResult queryResult = await statement.ExecuteQueryAsync(); + + await Tests.DriverTests.CanExecuteQueryAsync(queryResult, 0); + } + } +} diff --git a/csharp/test/Drivers/Apache/Common/NumericValueTests.cs b/csharp/test/Drivers/Apache/Common/NumericValueTests.cs new file mode 100644 index 0000000000..99ad341e3d --- /dev/null +++ b/csharp/test/Drivers/Apache/Common/NumericValueTests.cs @@ -0,0 +1,274 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System.Data.SqlTypes; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common +{ + // TODO: When supported, use prepared statements instead of SQL string literals + // Which will better test how the driver handles values sent/received + + public abstract class NumericValueTests : TestBase + where TConfig : TestConfiguration + where TEnv : HiveServer2TestEnvironment + { + /// + /// Validates that specific numeric values can be inserted, retrieved and targeted correctly + /// + public NumericValueTests(ITestOutputHelper output, TestEnvironment.Factory testEnvFactory) + : base(output, testEnvFactory) { } + + /// + /// Validates if driver can send and receive specific Integer values correctly + /// + [SkippableTheory] + [InlineData(-1)] + [InlineData(0)] + [InlineData(1)] + public async Task TestIntegerSanity(int value) + { + string columnName = "INTTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} INT", columnName)); + await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); + } + + /// + /// Validates if driver can handle the largest / smallest numbers + /// + [SkippableTheory] + [InlineData(int.MaxValue)] + [InlineData(int.MinValue)] + public async Task TestIntegerMinMax(int value) + { + string columnName = "INTTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} INT", columnName)); + await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); + } + + /// + /// Validates if driver can handle the largest / smallest numbers + /// + [SkippableTheory] + [InlineData(long.MaxValue)] + [InlineData(long.MinValue)] + public async Task TestLongMinMax(long value) + { + string columnName = "BIGINTTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} BIGINT", columnName)); + await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); + } + + /// + /// Validates if driver can handle the largest / smallest numbers + /// + [SkippableTheory] + [InlineData(short.MaxValue)] + [InlineData(short.MinValue)] + public async Task TestSmallIntMinMax(short value) + { + string columnName = "SMALLINTTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} SMALLINT", columnName)); + await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); + } + + /// + /// Validates if driver can handle the largest / smallest numbers + /// + [SkippableTheory] + [InlineData(sbyte.MaxValue)] + [InlineData(sbyte.MinValue)] + public async Task TestTinyIntMinMax(sbyte value) + { + string columnName = "TINYINTTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} TINYINT", columnName)); + await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); + } + + /// + /// Validates if driver can handle smaller Number type correctly + /// + [SkippableTheory] + [InlineData("-1")] + [InlineData("0")] + [InlineData("1")] + [InlineData("99")] + [InlineData("-99")] + public async Task TestSmallNumberRange(string value) + { + string columnName = "SMALLNUMBER"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(2,0)", columnName)); + object? 
expectedValue = TestEnvironment.GetValueForProtocolVersion(value, new SqlDecimal(double.Parse(value))); + await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, expectedValue); + } + + /// + /// Validates if driver correctly errors out when the values exceed the column's limit + /// + [SkippableTheory] + [InlineData(-100)] + [InlineData(100)] + [InlineData(int.MaxValue)] + [InlineData(int.MinValue)] + public async Task TestSmallNumberRangeOverlimit(int value) + { + string columnName = "SMALLNUMBER"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(2,0)", columnName)); + await Assert.ThrowsAsync( + async () => await ValidateInsertSelectDeleteSingleValueAsync( + table.TableName, + columnName, TestEnvironment.GetValueForProtocolVersion(value.ToString(), new SqlDecimal(value)))); + } + + /// + /// Validates if driver can handle a large scale Number type correctly + /// + [SkippableTheory] + [InlineData("0E-37")] + [InlineData("-2.0030000000000000000000000000000000000")] + [InlineData("4.8500000000000000000000000000000000000")] + [InlineData("1E-37")] + [InlineData("9.5545204502636499875576383003668916798")] + public async Task TestLargeScaleNumberRange(string value) + { + string columnName = "LARGESCALENUMBER"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,37)", columnName)); + await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, TestEnvironment.GetValueForProtocolVersion(value, new SqlDecimal(double.Parse(value)))); + } + + /// + /// Validates if driver can error handle when input goes beyond a large scale Number type + /// + [SkippableTheory] + [InlineData("-10")] + [InlineData("10")] + [InlineData("99999999999999999999999999999999999999")] + [InlineData("-99999999999999999999999999999999999999")] + public async Task TestLargeScaleNumberOverlimit(string value) + { + string columnName = "LARGESCALENUMBER"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,37)", columnName)); + await Assert.ThrowsAsync(async () => await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, SqlDecimal.Parse(value))); + } + + /// + /// Validates if driver can handle a small scale Number type correctly + /// + [SkippableTheory] + [InlineData("0.00")] + [InlineData("4.85")] + [InlineData("-999999999999999999999999999999999999.99")] + [InlineData("999999999999999999999999999999999999.99")] + public async Task TestSmallScaleNumberRange(string value) + { + string columnName = "SMALLSCALENUMBER"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)", columnName)); + await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, TestEnvironment.GetValueForProtocolVersion(value, SqlDecimal.Parse(value))); + } + + /// + /// Validates if driver can error handle when an insert goes beyond a small scale Number type correctly + /// + [SkippableTheory] + [InlineData("-99999999999999999999999999999999999999")] + [InlineData("99999999999999999999999999999999999999")] + public async Task TestSmallScaleNumberOverlimit(string value) + { + string columnName = "SMALLSCALENUMBER"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)", columnName)); + await Assert.ThrowsAsync(async () => await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, SqlDecimal.Parse(value))); + } + + + 
/// + /// Tests that decimals are rounded as expected. + /// The server allows inserts with more scale than the column declares, but the stored value is rounded up or down + /// + [SkippableTheory] + [InlineData(2.467, 2.47)] + [InlineData(-672.613, -672.61)] + public async Task TestRoundingNumbers(decimal input, decimal output) + { + string columnName = "SMALLSCALENUMBER"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)", columnName)); + SqlDecimal value = new SqlDecimal(input); + SqlDecimal returned = new SqlDecimal(output); + await InsertSingleValueAsync(table.TableName, columnName, value.ToString()); + await SelectAndValidateValuesAsync(table.TableName, columnName, TestEnvironment.GetValueForProtocolVersion(output.ToString(), returned), 1); + string whereClause = GetWhereClause(columnName, returned); + if (SupportsDelete) await DeleteFromTableAsync(table.TableName, whereClause, 1); + } + + /// + /// Validates that the driver handles the floating-point DOUBLE type correctly + /// + [SkippableTheory] + [InlineData(0)] + [InlineData(0.2)] + [InlineData(15e-03)] + [InlineData(1.234E+2)] + [InlineData(double.NegativeInfinity)] + [InlineData(double.PositiveInfinity)] + [InlineData(double.NaN)] + [InlineData(double.MinValue)] + [InlineData(double.MaxValue)] + public async Task TestDoubleValuesInsertSelectDelete(double value) + { + string columnName = "DOUBLETYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DOUBLE", columnName)); + string valueString = ConvertDoubleToString(value); + await InsertSingleValueAsync(table.TableName, columnName, valueString); + await SelectAndValidateValuesAsync(table.TableName, columnName, value, 1); + string whereClause = GetWhereClause(columnName, value); + if (SupportsDelete) await DeleteFromTableAsync(table.TableName, whereClause, 1); + } + + /// + /// Validates that the driver handles the floating-point FLOAT type correctly + /// + [SkippableTheory] + [InlineData(0)] + [InlineData(25)] + [InlineData(float.NegativeInfinity)] + [InlineData(float.PositiveInfinity)] + [InlineData(float.NaN)] + // TODO: Solve server issue when non-integer float value is used in where clause. + //[InlineData(25.1)] + //[InlineData(0.2)] + //[InlineData(15e-03)] + //[InlineData(1.234E+2)] + //[InlineData(float.MinValue)] + //[InlineData(float.MaxValue)] + public async Task TestFloatValuesInsertSelectDelete(float value) + { + string columnName = "FLOATTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} FLOAT", columnName)); + string valueString = ConvertFloatToString(value); + await InsertSingleValueAsync(table.TableName, columnName, valueString); + object doubleValue = (double)value; + // Spark over HTTP returns float as double whereas Spark on Databricks returns float. + object floatValue = TestEnvironment.DataTypeConversion.HasFlag(DataTypeConversion.Scalar) ? 
value : doubleValue; + await SelectAndValidateValuesAsync(table.TableName, columnName, floatValue, 1); + string whereClause = GetWhereClause(columnName, value); + if (SupportsDelete) await DeleteFromTableAsync(table.TableName, whereClause, 1); + } + } +} diff --git a/csharp/test/Drivers/Apache/Common/StatementTests.cs b/csharp/test/Drivers/Apache/Common/StatementTests.cs new file mode 100644 index 0000000000..b793b7686c --- /dev/null +++ b/csharp/test/Drivers/Apache/Common/StatementTests.cs @@ -0,0 +1,237 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tests.Xunit; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common +{ + /// + /// Class for testing statement functionality of the Apache HiveServer2-based ADBC drivers. + /// + /// + /// Tests are ordered to ensure data is created for the other + /// queries to run. + /// + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] + public abstract class StatementTests<TConfig, TEnv> : TestBase<TConfig, TEnv> + where TConfig : ApacheTestConfiguration + where TEnv : HiveServer2TestEnvironment<TConfig> + { + private static List<string> DefaultTableTypes => ["TABLE", "VIEW"]; + + public StatementTests(ITestOutputHelper? outputHelper, TestEnvironment.Factory testEnvFactory) + : base(outputHelper, testEnvFactory) + { + Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); + } + + /// + /// Validates that SetOption handles valid and invalid values correctly for the PollTime option (0 through int.MaxValue is accepted). + /// + [SkippableTheory] + [InlineData("-1", true)] + [InlineData("zero", true)] + [InlineData("-2147483648", true)] + [InlineData("2147483648", true)] + [InlineData("0")] + [InlineData("1")] + [InlineData("2147483647")] + public void CanSetOptionPollTime(string value, bool throws = false) + { + var testConfiguration = TestConfiguration.Clone() as TConfig; + testConfiguration!.PollTimeMilliseconds = value; + if (throws) + { + Assert.Throws(() => NewConnection(testConfiguration).CreateStatement()); + } + + AdbcStatement statement = NewConnection().CreateStatement(); + if (throws) + { + Assert.Throws(() => statement.SetOption(ApacheParameters.PollTimeMilliseconds, value)); + } + else + { + statement.SetOption(ApacheParameters.PollTimeMilliseconds, value); + } + } + + /// + /// Validates that SetOption handles valid and invalid values correctly for the BatchSize option (1 through long.MaxValue is accepted). 
+ /// + [SkippableTheory] + [InlineData("-1", true)] + [InlineData("one", true)] + [InlineData("-2147483648", true)] + [InlineData("2147483648", false)] + [InlineData("9223372036854775807", false)] + [InlineData("9223372036854775808", true)] + [InlineData("0", true)] + [InlineData("1")] + [InlineData("2147483647")] + public void CanSetOptionBatchSize(string value, bool throws = false) + { + var testConfiguration = TestConfiguration.Clone() as TConfig; + testConfiguration!.BatchSize = value; + if (throws) + { + Assert.Throws(() => NewConnection(testConfiguration).CreateStatement()); + } + + AdbcStatement statement = NewConnection().CreateStatement(); + if (throws) + { + Assert.Throws(() => statement!.SetOption(ApacheParameters.BatchSize, value)); + } + else + { + statement.SetOption(ApacheParameters.BatchSize, value); + } + } + + /// + /// Validates that SetOption handles valid and invalid values correctly for the QueryTimeout option (0 through int.MaxValue is accepted). + /// + [SkippableTheory] + [InlineData("zero", true)] + [InlineData("-2147483648", true)] + [InlineData("2147483648", true)] + [InlineData("0", false)] + [InlineData("-1", true)] + [InlineData("1")] + [InlineData("2147483647")] + public void CanSetOptionQueryTimeout(string value, bool throws = false) + { + var testConfiguration = TestConfiguration.Clone() as TConfig; + testConfiguration!.QueryTimeoutSeconds = value; + if (throws) + { + Assert.Throws(() => NewConnection(testConfiguration).CreateStatement()); + } + + AdbcStatement statement = NewConnection().CreateStatement(); + if (throws) + { + Assert.Throws(() => statement.SetOption(ApacheParameters.QueryTimeoutSeconds, value)); + } + else + { + statement.SetOption(ApacheParameters.QueryTimeoutSeconds, value); + } + } + + /// + /// Queries the backend with various timeouts. + /// + /// + [SkippableTheory] + [ClassData(typeof(StatementTimeoutTestData))] + internal void StatementTimeoutTest(StatementWithExceptions statementWithExceptions) + { + TConfig testConfiguration = (TConfig)TestConfiguration.Clone(); + + if (statementWithExceptions.QueryTimeoutSeconds.HasValue) + testConfiguration.QueryTimeoutSeconds = statementWithExceptions.QueryTimeoutSeconds.Value.ToString(); + + if (!string.IsNullOrEmpty(statementWithExceptions.Query)) + testConfiguration.Query = statementWithExceptions.Query!; + + OutputHelper?.WriteLine($"QueryTimeoutSeconds: {testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: {statementWithExceptions.ExceptionType == null}. Query: [{testConfiguration.Query}]"); + + try + { + AdbcStatement st = NewConnection(testConfiguration).CreateStatement(); + st.SqlQuery = testConfiguration.Query; + QueryResult qr = st.ExecuteQuery(); + + OutputHelper?.WriteLine($"QueryResultRowCount: {qr.RowCount}"); + } + catch (Exception ex) when (ApacheUtility.ContainsException(ex, statementWithExceptions.ExceptionType, out Exception? containedException)) + { + Assert.IsType(statementWithExceptions.ExceptionType!, containedException!); + } + } + + /// + /// Validates that the driver can execute statements after setting statement options. + /// + [SkippableFact, Order(1)] + public async Task CanInteractUsingSetOptions() + { + const string columnName = "INDEX"; + Statement.SetOption(ApacheParameters.PollTimeMilliseconds, "100"); + Statement.SetOption(ApacheParameters.BatchSize, "10"); + using TemporaryTable temporaryTable = await NewTemporaryTableAsync(Statement, $"{columnName} INT"); + await ValidateInsertSelectDeleteSingleValueAsync(temporaryTable.TableName, columnName, 1); + } + } + + /// + /// Data type used for statement timeout tests. 
+ /// + internal class StatementWithExceptions + { + public StatementWithExceptions(int? queryTimeoutSeconds, string? query, Type? exceptionType) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + Query = query; + ExceptionType = exceptionType; + } + + /// + /// If null, uses the default timeout. + /// + public int? QueryTimeoutSeconds { get; } + + /// + /// If null, expected to succeed. + /// + public Type? ExceptionType { get; } + + /// + /// If null, uses the default TestConfiguration. + /// + public string? Query { get; } + } + + /// + /// Collection of test statements and expected exceptions for testing statement timeouts. + /// + internal class StatementTimeoutTestData : TheoryData<StatementWithExceptions> + { + public StatementTimeoutTestData() + { + string longRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(10000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0"; + + Add(new(0, null, null)); + Add(new(null, null, null)); + Add(new(1, null, typeof(TimeoutException))); + Add(new(5, null, null)); + Add(new(30, null, null)); + Add(new(5, longRunningQuery, typeof(TimeoutException))); + Add(new(null, longRunningQuery, typeof(TimeoutException))); + Add(new(0, longRunningQuery, null)); + } + } +} diff --git a/csharp/test/Drivers/Apache/Common/StringValueTests.cs b/csharp/test/Drivers/Apache/Common/StringValueTests.cs new file mode 100644 index 0000000000..e861f7ebfa --- /dev/null +++ b/csharp/test/Drivers/Apache/Common/StringValueTests.cs @@ -0,0 +1,129 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common +{ + // TODO: When supported, use prepared statements instead of SQL string literals, + // which will better test how the driver handles values sent/received + + /// + /// Validates that specific string and character values can be inserted, retrieved and targeted correctly + /// + public abstract class StringValueTests<TConfig, TEnv> : TestBase<TConfig, TEnv> + where TConfig : TestConfiguration + where TEnv : HiveServer2TestEnvironment<TConfig> + { + public StringValueTests(ITestOutputHelper output, TestEnvironment.Factory testEnvFactory) + : base(output, testEnvFactory) { } + + public static IEnumerable<object[]> ByteArrayData(int size) + { + var rnd = new Random(); + byte[] bytes = new byte[size]; + rnd.NextBytes(bytes); + yield return new object[] { bytes }; + } + + /// + /// Validates that the driver can send and receive specific String values correctly. 
+ /// + [SkippableTheory] + [InlineData(null)] + [InlineData("")] + [InlineData("你好")] + [InlineData(" Leading and trailing spaces ")] + protected virtual async Task TestStringData(string? value) + { + string columnName = "STRINGTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "STRING")); + await ValidateInsertSelectDeleteSingleValueAsync( + table.TableName, + columnName, + value, + value != null ? QuoteValue(value) : value); + } + + /// + /// Validates if driver can send and receive specific VARCHAR values correctly. + /// + [SkippableTheory] + [InlineData(null)] + [InlineData("")] + [InlineData("你好")] + [InlineData(" Leading and trailing spaces ")] + protected virtual async Task TestVarcharData(string? value) + { + string columnName = "VARCHARTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "VARCHAR(100)")); + await ValidateInsertSelectDeleteSingleValueAsync( + table.TableName, + columnName, + value, + value != null ? QuoteValue(value) : value); + } + + /// + /// Validates if driver can send and receive specific VARCHAR values correctly. + /// + [SkippableTheory] + [InlineData(null)] + [InlineData("")] + [InlineData("你好")] + [InlineData(" Leading and trailing spaces ")] + protected virtual async Task TestCharData(string? value) + { + string columnName = "CHARTYPE"; + int fieldLength = 100; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, $"CHAR({fieldLength})")); + + string? formattedValue = value != null ? QuoteValue(value.PadRight(fieldLength)) : value; + string? paddedValue = value != null ? value.PadRight(fieldLength) : value; + + await InsertSingleValueAsync(table.TableName, columnName, formattedValue); + await SelectAndValidateValuesAsync(table.TableName, columnName, paddedValue, 1, formattedValue); + string whereClause = GetWhereClause(columnName, formattedValue ?? paddedValue); + if (SupportsDelete) await DeleteFromTableAsync(table.TableName, whereClause, 1); + } + + /// + /// Validates if driver fails to insert invalid length of VARCHAR value. + /// + [SkippableTheory] + [InlineData("String whose length is too long for VARCHAR(10).", new string[] { "Exceeds", "length limitation: 10" }, null)] + protected virtual async Task TestVarcharExceptionData(string value, string[] expectedTexts, string? expectedSqlState) + { + string columnName = "VARCHARTYPE"; + using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "VARCHAR(10)")); + AdbcException exception = await Assert.ThrowsAsync(async () => await ValidateInsertSelectDeleteSingleValueAsync( + table.TableName, + columnName, + value, + value != null ? QuoteValue(value) : value)); + + AssertContainsAll(expectedTexts, exception.Message); + Assert.Equal(expectedSqlState, exception.SqlState); + } + } +} diff --git a/csharp/test/Drivers/Apache/Hive2/DecimalUtilityTests.cs b/csharp/test/Drivers/Apache/Hive2/DecimalUtilityTests.cs new file mode 100644 index 0000000000..467317c40a --- /dev/null +++ b/csharp/test/Drivers/Apache/Hive2/DecimalUtilityTests.cs @@ -0,0 +1,182 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Buffers.Text; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Globalization; +using System.Text; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2 +{ + /// + /// Class for testing the Decimal Utilities tests. + /// + public class DecimalUtilityTests(ITestOutputHelper outputHelper) + { + private readonly ITestOutputHelper _outputHelper = outputHelper; + + [SkippableTheory] + [MemberData(nameof(Decimal128Data))] + public void TestCanConvertDecimal(string stringValue, int precision, int scale, int byteWidth, byte[] expected, SqlDecimal? expectedDecimal = default) + { + ReadOnlySpan value = Encoding.UTF8.GetBytes(stringValue); + byte[] actual = new byte[byteWidth]; + DecimalUtility.GetBytes(value, precision, scale, byteWidth, actual); + Assert.Equal(expected, actual); + Assert.Equal(0, byteWidth % 4); + int[] buffer = new int[byteWidth / 4]; + for (int i = 0; i < buffer.Length; i++) + { + buffer[i] = BitConverter.ToInt32(actual, i * sizeof(int)); + } + SqlDecimal actualDecimal = GetSqlDecimal128(actual, 0, precision, scale); + if (expectedDecimal != null) Assert.Equal(expectedDecimal, actualDecimal); + } + + [Fact(Skip = "Run manually to confirm equivalent performance")] + public void TestConvertDecimalPerformance() + { + Stopwatch stopwatch = new(); + + int testCount = 1000000; + ReadOnlySpan testValue = "99999999999999999999999999999999999999"u8; + string testValueString = "99999999999999999999999999999999999999"; + int byteWidth = 16; + byte[] buffer = new byte[byteWidth]; + Decimal128Array.Builder builder = new(new Types.Decimal128Type(38, 0)); + stopwatch.Restart(); + for (int i = 0; i < testCount; i++) + { + if (Utf8Parser.TryParse(testValue, out decimal actualDecimal, out _, standardFormat: 'E')) + { + builder.Append(new SqlDecimal(actualDecimal)); + } + else + { + builder.Append(testValueString); + } + } + stopwatch.Stop(); + _outputHelper.WriteLine($"Decimal128Builder.Append: {testCount} iterations took {stopwatch.ElapsedMilliseconds} elapsed milliseconds"); + + builder = new(new Types.Decimal128Type(38, 0)); + stopwatch.Restart(); + for (int i = 0; i < testCount; i++) + { + DecimalUtility.GetBytes(testValue, 38, 0, byteWidth, buffer); + builder.Append(buffer); + } + stopwatch.Stop(); + _outputHelper.WriteLine($"DecimalUtility.GetBytes: {testCount} iterations took {stopwatch.ElapsedMilliseconds} elapsed milliseconds"); + } + + public static IEnumerable Decimal128Data() + { + yield return new object[] { "0", 1, 0, 16, new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0) }; + + yield return new object[] { "1", 1, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "1E0", 1, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 
new SqlDecimal(1) }; + yield return new object[] { "10e-1", 1, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "0.1e1", 1, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + + yield return new object[] { "12", 2, 0, 16, new byte[] { 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(12) }; + yield return new object[] { "12E0", 2, 0, 16, new byte[] { 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(12) }; + yield return new object[] { "120e-1", 2, 0, 16, new byte[] { 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(12) }; + yield return new object[] { "1.2e1", 2, 0, 16, new byte[] { 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(12) }; + + yield return new object[] { "99999999999999999999999999999999999999", 38, 0, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "99999999999999999999999999999999999999E0", 38, 0, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "999999999999999999999999999999999999990e-1", 38, 0, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "0.99999999999999999999999999999999999999e38", 38, 0, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + + yield return new object[] { "-1", 1, 0, 16, new byte[] { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-1) }; + yield return new object[] { "-1E0", 1, 0, 16, new byte[] { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-1) }; + yield return new object[] { "-10e-1", 1, 0, 16, new byte[] { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-1) }; + yield return new object[] { "-0.1e1", 1, 0, 16, new byte[] { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-1) }; + + yield return new object[] { "-12", 2, 0, 16, new byte[] { 244, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-12) }; + yield return new object[] { "-12E0", 2, 0, 16, new byte[] { 244, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-12) }; + yield return new object[] { "-120e-1", 2, 0, 16, new byte[] { 244, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-12) }; + yield return new object[] { "-1.2e1", 2, 0, 16, new byte[] { 244, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-12) }; + + yield return new object[] { "1", 38, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "1E0", 38, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "10e-1", 38, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "0.1e1", 38, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + + yield return new object[] { "1", 3, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] 
{ "1E0", 3, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "10e-1", 3, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "0.1e1", 3, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + + yield return new object[] { "1", 38, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "1E0", 38, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "10e-1", 38, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "0.1e1", 38, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + + yield return new object[] { "0.1", 38, 1, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "0.1E0", 38, 1, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "1e-1", 38, 1, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "0.01e1", 38, 1, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + + yield return new object[] { "0.1", 38, 3, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "0.1E0", 38, 3, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "1e-1", 38, 3, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "0.01e1", 38, 3, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + + yield return new object[] { "-0.1", 38, 3, 16, new byte[] { 156, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-0.1) }; + yield return new object[] { "-0.1E0", 38, 3, 16, new byte[] { 156, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-0.1) }; + yield return new object[] { "-1e-1", 38, 3, 16, new byte[] { 156, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-0.1) }; + yield return new object[] { "-0.01e1", 38, 3, 16, new byte[] { 156, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-0.1) }; + + yield return new object[] { "0.99999999999999999999999999999999999999", 38, 38, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "0.99999999999999999999999999999999999999E0", 38, 38, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "9.99999999999999999999999999999999999990e-1", 38, 38, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "0.0000000000000000000000000000000000000099999999999999999999999999999999999999e38", 38, 38, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + } + + private static SqlDecimal GetSqlDecimal128(in byte[] valueBuffer, int index, int precision, int scale) + { 
+ const int byteWidth = 16; + const int intWidth = byteWidth / 4; + const int longWidth = byteWidth / 8; + + byte mostSignificantByte = valueBuffer.AsSpan()[(index + 1) * byteWidth - 1]; + bool isPositive = (mostSignificantByte & 0x80) == 0; + + if (isPositive) + { + ReadOnlySpan value = valueBuffer.AsSpan().CastTo().Slice(index * intWidth, intWidth); + return new SqlDecimal((byte)precision, (byte)scale, true, value[0], value[1], value[2], value[3]); + } + else + { + ReadOnlySpan value = valueBuffer.AsSpan().CastTo().Slice(index * longWidth, longWidth); + long data1 = -value[0]; + long data2 = data1 == 0 ? -value[1] : ~value[1]; + + return new SqlDecimal((byte)precision, (byte)scale, false, (int)(data1 & 0xffffffff), (int)(data1 >> 32), (int)(data2 & 0xffffffff), (int)(data2 >> 32)); + } + } + } +} diff --git a/csharp/test/Drivers/Apache/Hive2/HiveServer2ParametersTest.cs b/csharp/test/Drivers/Apache/Hive2/HiveServer2ParametersTest.cs new file mode 100644 index 0000000000..992e5ffb1d --- /dev/null +++ b/csharp/test/Drivers/Apache/Hive2/HiveServer2ParametersTest.cs @@ -0,0 +1,93 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Xunit; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2 +{ + public class HiveServer2ParametersTest + { + [SkippableTheory] + [MemberData(nameof(GetParametersDataTypeConvTestData))] + internal void TestParametersDataTypeConvParse(string? dataTypeConversion, DataTypeConversion expected, Type? exceptionType = default) + { + if (exceptionType == default) + Assert.Equal(expected, DataTypeConversionParser.Parse(dataTypeConversion)); + else + Assert.Throws(exceptionType, () => DataTypeConversionParser.Parse(dataTypeConversion)); + } + + [SkippableTheory] + [MemberData(nameof(GetParametersTlsOptionTestData))] + internal void TestParametersTlsOptionParse(string? tlsOptions, HiveServer2TlsOption expected, Type? 
exceptionType = default) + { + if (exceptionType == default) + Assert.Equal(expected, TlsOptionsParser.Parse(tlsOptions)); + else + Assert.Throws(exceptionType, () => TlsOptionsParser.Parse(tlsOptions)); + } + + public static IEnumerable GetParametersDataTypeConvTestData() + { + // Default + yield return new object?[] { null, DataTypeConversion.Scalar }; + yield return new object?[] { "", DataTypeConversion.Scalar }; + yield return new object?[] { ",", DataTypeConversion.Scalar }; + // Explicit + yield return new object?[] { $"scalar", DataTypeConversion.Scalar }; + yield return new object?[] { $"none", DataTypeConversion.None }; + // Ignore "empty", embedded space, mixed-case + yield return new object?[] { $"scalar,", DataTypeConversion.Scalar }; + yield return new object?[] { $",scalar,", DataTypeConversion.Scalar }; + yield return new object?[] { $",scAlAr,", DataTypeConversion.Scalar }; + yield return new object?[] { $"scAlAr", DataTypeConversion.Scalar }; + yield return new object?[] { $" scalar ", DataTypeConversion.Scalar }; + // Combined - conflicting + yield return new object?[] { $"none,scalar", DataTypeConversion.None | DataTypeConversion.Scalar, typeof(ArgumentOutOfRangeException) }; + yield return new object?[] { $" nOnE, scAlAr ", DataTypeConversion.None | DataTypeConversion.Scalar, typeof(ArgumentOutOfRangeException) }; + yield return new object?[] { $", none, scalar, ", DataTypeConversion.None | DataTypeConversion.Scalar , typeof(ArgumentOutOfRangeException) }; + yield return new object?[] { $"scalar,none", DataTypeConversion.None | DataTypeConversion.Scalar , typeof(ArgumentOutOfRangeException) }; + // Invalid options + yield return new object?[] { $"xxx", DataTypeConversion.Empty, typeof(ArgumentOutOfRangeException) }; + yield return new object?[] { $"none,scalar,xxx", DataTypeConversion.None | DataTypeConversion.Scalar, typeof(ArgumentOutOfRangeException) }; + } + + public static IEnumerable GetParametersTlsOptionTestData() + { + // Default + yield return new object?[] { null, HiveServer2TlsOption.Empty }; + yield return new object?[] { "", HiveServer2TlsOption.Empty}; + yield return new object?[] { " ", HiveServer2TlsOption.Empty }; + // Explicit + yield return new object?[] { $"{TlsOptions.AllowSelfSigned}", HiveServer2TlsOption.AllowSelfSigned }; + yield return new object?[] { $"{TlsOptions.AllowHostnameMismatch}", HiveServer2TlsOption.AllowHostnameMismatch }; + // Ignore empty + yield return new object?[] { $",{TlsOptions.AllowSelfSigned}", HiveServer2TlsOption.AllowSelfSigned }; + yield return new object?[] { $",{TlsOptions.AllowHostnameMismatch},", HiveServer2TlsOption.AllowHostnameMismatch }; + // Combined, embedded space, mixed-case + yield return new object?[] { $"{TlsOptions.AllowSelfSigned},{TlsOptions.AllowHostnameMismatch}", HiveServer2TlsOption.AllowSelfSigned | HiveServer2TlsOption.AllowHostnameMismatch }; + yield return new object?[] { $"{TlsOptions.AllowHostnameMismatch},{TlsOptions.AllowSelfSigned}", HiveServer2TlsOption.AllowSelfSigned | HiveServer2TlsOption.AllowHostnameMismatch }; + yield return new object?[] { $" {TlsOptions.AllowHostnameMismatch} , {TlsOptions.AllowSelfSigned} ", HiveServer2TlsOption.AllowSelfSigned | HiveServer2TlsOption.AllowHostnameMismatch }; + yield return new object?[] { $"{TlsOptions.AllowSelfSigned.ToUpperInvariant()},{TlsOptions.AllowHostnameMismatch.ToUpperInvariant()}", HiveServer2TlsOption.AllowSelfSigned | HiveServer2TlsOption.AllowHostnameMismatch }; + // Invalid + yield return new object?[] { 
$"xxx,{TlsOptions.AllowSelfSigned.ToUpperInvariant()},{TlsOptions.AllowHostnameMismatch.ToUpperInvariant()}", HiveServer2TlsOption.Empty, typeof(ArgumentOutOfRangeException) }; + } + } +} diff --git a/csharp/test/Drivers/Apache/Hive2/HiveServer2ReaderTest.cs b/csharp/test/Drivers/Apache/Hive2/HiveServer2ReaderTest.cs new file mode 100644 index 0000000000..d784e23518 --- /dev/null +++ b/csharp/test/Drivers/Apache/Hive2/HiveServer2ReaderTest.cs @@ -0,0 +1,265 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Globalization; +using System.Text; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Xunit; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2 +{ + public class HiveServer2ReaderTest + { + private const bool IsValid = true; + private const bool IsNotValid = false; + + [Theory] + [MemberData(nameof(GetDateTestData), /* isKnownFormat */ true)] + internal void TestCanConvertKnownFormatDate(string date, DateTime expected, bool isValid) + { + ReadOnlySpan dateSpan = Encoding.UTF8.GetBytes(date).AsSpan(); + if (isValid) + { + Assert.True(HiveServer2Reader.TryParse(dateSpan, out DateTime dateTime)); + Assert.Equal(expected, dateTime); + } + else + { + Assert.False(HiveServer2Reader.TryParse(dateSpan, out DateTime _)); + } + } + + [Theory] + [MemberData(nameof(GetDateTestData), /* isKnownFormat */ false)] + internal void TestCanConvertUnknownFormatDate(string date, DateTime expected, bool isValid) + { + var builder = new StringArray.Builder(); + builder.Append(date); + var stringArray = builder.Build(); + if (isValid) + { + var dateArray = HiveServer2Reader.ConvertToDate32(stringArray, stringArray.Data.DataType); + Assert.Equal(1, dateArray.Length); + Assert.Equal(expected, dateArray.GetDateTime(0)); + } + else + { + Assert.Throws(() => HiveServer2Reader.ConvertToDate32(stringArray, stringArray.Data.DataType)); + } + } + + [Theory] + [MemberData(nameof(GetTimestampTestData), /* isKnownFormat */ true)] + internal void TestCanConvertKnownFormatTimestamp(string date, DateTimeOffset expected, bool isValid) + { + ReadOnlySpan dateSpan = Encoding.UTF8.GetBytes(date).AsSpan(); + if (isValid) + { + Assert.True(HiveServer2Reader.TryParse(dateSpan, out DateTimeOffset dateTime)); + Assert.Equal(expected, dateTime); + } + else + { + Assert.False(HiveServer2Reader.TryParse(dateSpan, out DateTimeOffset _)); + } + } + + [Theory] + [MemberData(nameof(GetTimestampTestData), /* isKnownFormat */ false)] + internal void TestCanConvertUnknownFormatTimestamp(string date, DateTimeOffset expected, bool isValid) + { + var builder = new StringArray.Builder(); + builder.Append(date); + var stringArray = builder.Build(); + if (isValid) + { + TimestampArray timestampArray = HiveServer2Reader.ConvertToTimestamp(stringArray, 
stringArray.Data.DataType); + Assert.Equal(1, timestampArray.Length); + Assert.Equal(expected, timestampArray.GetTimestamp(0)); + } + else + { + Assert.Throws(() => HiveServer2Reader.ConvertToTimestamp(stringArray, stringArray.Data.DataType)); + } + } + + public static TheoryData GetDateTestData(bool isKnownFormat) + { + string[] dates = + [ + "0001-01-01", + "0001-12-31", + "1970-01-01", + "2024-12-31", + "9999-12-31", + ]; + + var data = new TheoryData(); + foreach (string date in dates) + { + data.Add(date, DateTime.Parse(date, CultureInfo.InvariantCulture), IsValid); + } + + // Conditionally invalid component separators + string[] leadingSpaces = ["", " "]; + string[] TrailingSpaces = ["", " "]; + string[] separators = ["/", " "]; + foreach (string leadingSpace in leadingSpaces) + { + foreach (string trailingSpace in TrailingSpaces) + { + foreach (string separator in separators) + { + foreach (string date in dates) + { + data.Add(leadingSpace + date.Replace("-", separator) + trailingSpace, DateTime.Parse(date), !isKnownFormat); + } + } + } + } + + // Always invalid for a date separator + separators = [":"]; + foreach (string leadingSpace in leadingSpaces) + { + foreach (string trailingSpace in TrailingSpaces) + { + foreach (string separator in separators) + { + foreach (string date in dates) + { + data.Add(leadingSpace + date.Replace("-", separator) + trailingSpace, default, IsNotValid); + } + } + } + } + + string[] invalidDates = + [ + "0001-01-00", + "0001-01-32", + "0001-02-30", + "0001-13-01", + "00a1-01-01", + "0001-a1-01", + "0001-01-a1", + "001a-01-01", + "0001-1a-01", + "0001-01-1a", + ]; + foreach (string date in invalidDates) + { + data.Add(date, default, IsNotValid); + } + + return data; + } + + public static TheoryData GetTimestampTestData(bool isKnownFormat) + { + string[] dates = + [ + "0001-01-01 00:00:00", + "9999-12-31 23:59:59", + "0001-01-01 00:00:00.1000000", + "0001-12-31 00:00:00.0100000", + "1970-01-01 00:00:00.0010000", + "2024-12-31 00:00:00.0001000", + "9999-12-31 00:00:00.0000100", + "9999-12-31 00:00:00.", + "9999-12-31 00:00:00.9", + "9999-12-31 00:00:00.99", + "9999-12-31 00:00:00.999", + "9999-12-31 00:00:00.9999", + "9999-12-31 00:00:00.99999", + "9999-12-31 00:00:00.999999", + "9999-12-31 00:00:00.999999", + "9999-12-31 00:00:00.9999990", + "9999-12-31 00:00:00.99999900", + ]; + + var data = new TheoryData(); + foreach (string date in dates) + { + data.Add(date, DateTimeOffset.Parse(date, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal), IsValid); + } + + // Conditionally invalid component separators + string[] leadingSpaces = ["", " "]; + string[] TrailingSpaces = ["", " "]; + string[] dateSeparators = ["/", " "]; + foreach (string leadingSpace in leadingSpaces) + { + foreach (string trailingSpace in TrailingSpaces) + { + foreach (string separator in dateSeparators) + { + foreach (string date in dates) + { + data.Add( + leadingSpace + date.Replace("-", separator) + trailingSpace, + DateTimeOffset.Parse(date, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal), + !isKnownFormat); + } + } + } + } + + // Always an invalid separator for date. 
+ dateSeparators = [":"]; + foreach (string leadingSpace in leadingSpaces) + { + foreach (string trailingSpace in TrailingSpaces) + { + foreach (string separator in dateSeparators) + { + foreach (string date in dates) + { + data.Add(leadingSpace + date.Replace("-", separator) + trailingSpace, default, IsNotValid); + } + } + } + } + + string[] invalidDates = + [ + "0001-01-00 00:00:00", + "0001-01-32 00:00:00", + "0001-02-30 00:00:00", + "0001-13-01 00:00:00", + "abcd-13-01 00:00:00", + "0001-12-01 00:00:00.abc", + "00a1-01-01 00:00:00", + "0001-a1-01 00:00:00", + "0001-01-a1 00:00:00", + "0001-01-01 a0:00:00", + "0001-01-01 00:a0:00", + "0001-01-01 00:00:a0", + "001a-01-01 00:00:00", + "0010-1a-01 00:00:00", + "0010-10-1a 00:00:00", + ]; + foreach (string date in invalidDates) + { + data.Add(date, default, IsNotValid); + } + + return data; + } + } +} diff --git a/csharp/test/Drivers/Apache/Hive2/HiveServer2TestEnvironment.cs b/csharp/test/Drivers/Apache/Hive2/HiveServer2TestEnvironment.cs new file mode 100644 index 0000000000..f141b293c5 --- /dev/null +++ b/csharp/test/Drivers/Apache/Hive2/HiveServer2TestEnvironment.cs @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2 +{ + public abstract class HiveServer2TestEnvironment : TestEnvironment + where TConfig : TestConfiguration + { + public HiveServer2TestEnvironment(Func getConnection) + : base(getConnection) + { + } + + internal DataTypeConversion DataTypeConversion => ((HiveServer2Connection)Connection).DataTypeConversion; + + public string? GetValueForProtocolVersion(string? unconvertedValue, string? convertedValue) => + ((HiveServer2Connection)Connection).DataTypeConversion.HasFlag(DataTypeConversion.None) ? unconvertedValue : convertedValue; + + public object? GetValueForProtocolVersion(object? unconvertedValue, object? convertedValue) => + ((HiveServer2Connection)Connection).DataTypeConversion.HasFlag(DataTypeConversion.None) ? unconvertedValue : convertedValue; + + } +} diff --git a/csharp/test/Drivers/Apache/Impala/ImpalaTestEnvironment.cs b/csharp/test/Drivers/Apache/Impala/ImpalaTestEnvironment.cs new file mode 100644 index 0000000000..a81444cb1f --- /dev/null +++ b/csharp/test/Drivers/Apache/Impala/ImpalaTestEnvironment.cs @@ -0,0 +1,79 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Impala; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala +{ + public class ImpalaTestEnvironment : TestEnvironment + { + public class Factory : Factory + { + public override ImpalaTestEnvironment Create(Func getConnection) => new(getConnection); + } + + private ImpalaTestEnvironment(Func getConnection) : base(getConnection) { } + + public override string TestConfigVariable => "IMPALA_TEST_CONFIG_FILE"; + + public override string SqlDataResourceLocation => "Impala/Resources/ImpalaData.sql"; + + public override int ExpectedColumnCount => 17; + + public override AdbcDriver CreateNewDriver() => new ImpalaDriver(); + + public override string GetCreateTemporaryTableStatement(string tableName, string columns) + { + return string.Format("CREATE TABLE {0} ({1})", tableName, columns); + } + + public override string Delimiter => "`"; + + public override Dictionary GetDriverParameters(ApacheTestConfiguration testConfiguration) + { + Dictionary parameters = new(StringComparer.OrdinalIgnoreCase); + + if (!string.IsNullOrEmpty(testConfiguration.HostName)) + { + parameters.Add("HostName", testConfiguration.HostName!); + } + if (!string.IsNullOrEmpty(testConfiguration.Port)) + { + parameters.Add("Port", testConfiguration.Port!); + } + return parameters; + } + + public override string VendorVersion => ((HiveServer2Connection)Connection).VendorVersion; + + public override bool SupportsDelete => false; + + public override bool SupportsUpdate => false; + + public override bool SupportCatalogName => false; + + public override bool ValidateAffectedRows => false; + + public override string GetInsertStatement(string tableName, string columnName, string? value) => + string.Format("INSERT INTO {0} ({1}) SELECT {2};", tableName, columnName, value ?? "NULL"); + + public override SampleDataBuilder GetSampleDataBuilder() => throw new NotImplementedException(); + } +} diff --git a/csharp/test/Drivers/Apache/Impala/ImpalaTests.cs b/csharp/test/Drivers/Apache/Impala/ImpalaTests.cs index ae19e22247..f0eee3e64f 100644 --- a/csharp/test/Drivers/Apache/Impala/ImpalaTests.cs +++ b/csharp/test/Drivers/Apache/Impala/ImpalaTests.cs @@ -15,36 +15,30 @@ * limitations under the License. */ -using System; using System.Collections.Generic; using Apache.Arrow.Adbc.Drivers.Apache.Impala; using Apache.Arrow.Adbc.Tests.Xunit; using Xunit; +using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala { [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] - public class ImpalaTests + public class ImpalaTests : TestBase { - [SkippableFact, Order(1)] - public void CanDriverConnect() + public ImpalaTests(ITestOutputHelper? 
outputHelper) + : base(outputHelper, new ImpalaTestEnvironment.Factory()) { - ApacheTestConfiguration testConfiguration = Utils.GetTestConfiguration("impalaconfig.json"); - - Dictionary parameters = new Dictionary(StringComparer.OrdinalIgnoreCase) - { - { "HostName", testConfiguration.HostName }, - { "Port", testConfiguration.Port }, - }; + } - AdbcDatabase database = new ImpalaDriver().Open(parameters); - AdbcConnection connection = database.Connect(new Dictionary()); - AdbcStatement statement = connection.CreateStatement(); - statement.SqlQuery = testConfiguration.Query; + [SkippableFact, Order(1)] + public void CanExecuteQuery() + { + AdbcStatement statement = Connection.CreateStatement(); + statement.SqlQuery = TestConfiguration.Query; QueryResult queryResult = statement.ExecuteQuery(); - //Adbc.Tests.ConnectionTests.CanDriverConnect(queryResult, testConfiguration.ExpectedResultsCount); - + DriverTests.CanExecuteQuery(queryResult, TestConfiguration.ExpectedResultsCount); } } } diff --git a/csharp/test/Drivers/Apache/Impala/Resources/ImpalaData.sql b/csharp/test/Drivers/Apache/Impala/Resources/ImpalaData.sql new file mode 100644 index 0000000000..908ffbb930 --- /dev/null +++ b/csharp/test/Drivers/Apache/Impala/Resources/ImpalaData.sql @@ -0,0 +1,133 @@ + + -- Licensed to the Apache Software Foundation (ASF) under one or more + -- contributor license agreements. See the NOTICE file distributed with + -- this work for additional information regarding copyright ownership. + -- The ASF licenses this file to You under the Apache License, Version 2.0 + -- (the "License"); you may not use this file except in compliance with + -- the License. You may obtain a copy of the License at + + -- http://www.apache.org/licenses/LICENSE-2.0 + + -- Unless required by applicable law or agreed to in writing, software + -- distributed under the License is distributed on an "AS IS" BASIS, + -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + -- See the License for the specific language governing permissions and + -- limitations under the License. 
+ +DROP TABLE IF EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}; + +CREATE TABLE IF NOT EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id LONG, + byte BYTE, + short SHORT, + integer INT, + float FLOAT, + number DOUBLE, + decimal NUMERIC(38, 9), + is_active BOOLEAN, + name STRING, + data BINARY, + date DATE, + timestamp TIMESTAMP, + timestamp_ntz TIMESTAMP_NTZ, + timestamp_ltz TIMESTAMP_LTZ, + numbers ARRAY, + person STRUCT < + name STRING, + age LONG + >, + map MAP < + INT, + STRING + >, + varchar VARCHAR(255), + char CHAR(10) +); + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 1, + 2, 3, 4, 7.89, 1.23, 4.56, + TRUE, + 'John Doe', + -- hex-encoded value `abc123` + X'616263313233', + '2023-09-08', '2023-09-08 12:34:56', '2023-09-08 12:34:56', '2023-09-08 12:34:56+00:00', + ARRAY(1, 2, 3), + STRUCT('John Doe', 30), + MAP(1, 'John Doe'), + 'John Doe', + 'John Doe' +); + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 2, + 127, 32767, 2147483647, 3.4028234663852886e+38, 1.7976931348623157e+308, 9.99999999999999999999999999999999E+28BD, + FALSE, + 'Jane Doe', + -- hex-encoded `def456` + X'646566343536', + '2023-09-09', '2023-09-09 13:45:57', '2023-09-09 13:45:57', '2023-09-09 13:45:57+00:00', + ARRAY(4, 5, 6), + STRUCT('Jane Doe', 40), + MAP(1, 'John Doe'), + 'Jane Doe', + 'Jane Doe' +); + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 3, + -128, -32768, -2147483648, -3.4028234663852886e+38, -1.7976931348623157e+308, -9.99999999999999999999999999999999E+28BD, + FALSE, + 'Jack Doe', + -- hex-encoded `def456` + X'646566343536', + '1556-01-02', '1970-01-01 00:00:00', '1970-01-01 00:00:00', '9999-12-31 23:59:59+00:00', + ARRAY(7, 8, 9), + STRUCT('Jack Doe', 50), + MAP(1, 'John Doe'), + 'Jack Doe', + 'Jack Doe' +); + +UPDATE {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + SET short = 0 + WHERE id = 3; + +DELETE FROM {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + WHERE id = 3; diff --git a/csharp/test/Drivers/Apache/Impala/Resources/impalaconfig.json b/csharp/test/Drivers/Apache/Impala/Resources/impalaconfig.json index 550fd3a97c..acd5c1b983 100644 --- a/csharp/test/Drivers/Apache/Impala/Resources/impalaconfig.json +++ b/csharp/test/Drivers/Apache/Impala/Resources/impalaconfig.json @@ -1,6 +1,12 @@ { - "hostName": "", - "port": "", - "query": "", - "expectedResults": 0 + "environment": "Impala", + "hostName": "", + "port": "", + "query": "", + "expectedResults": 0, + "metadata": { + "schema": "", + "table": "", + "expectedColumnCount": 0 + } } diff --git a/csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs b/csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs index fa19558f64..192db59e37 100644 --- a/csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs +++ b/csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs @@ -15,94 +15,15 @@ * limitations under the License. 
*/ -using System; -using System.Collections.Generic; -using System.Globalization; -using System.Threading.Tasks; -using Xunit; using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { - // TODO: When supported, use prepared statements instead of SQL string literals - // Which will better test how the driver handles values sent/received - - /// - /// Validates that specific binary and boolean values can be inserted, retrieved and targeted correctly - /// - public class BinaryBooleanValueTests : SparkTestBase + public class BinaryBooleanValueTests : Common.BinaryBooleanValueTests { - public BinaryBooleanValueTests(ITestOutputHelper output) : base(output) { } - - public static IEnumerable ByteArrayData(int size) - { - var rnd = new Random(); - byte[] bytes = new byte[size]; - rnd.NextBytes(bytes); - yield return new object[] { bytes }; - } - - /// - /// Validates if driver can send and receive specific Binary values correctly. - /// - [SkippableTheory] - [MemberData(nameof(ByteArrayData), 0)] - [MemberData(nameof(ByteArrayData), 2)] - [MemberData(nameof(ByteArrayData), 1024)] - public async Task TestBinaryData(byte[] value) - { - string columnName = "BINARYTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "BINARY")); - string formattedValue = $"X'{BitConverter.ToString(value).Replace("-", "")}'"; - await ValidateInsertSelectDeleteSingleValueAsync( - table.TableName, - columnName, - value, - formattedValue); - } - - /// - /// Validates if driver can send and receive specific Boolean values correctly. - /// - [SkippableTheory] - [InlineData(null)] - [InlineData(true)] - [InlineData(false)] - public async Task TestBooleanData(bool? value) - { - string columnName = "BOOLEANTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "BOOLEAN")); - string? formattedValue = value == null ? null : QuoteValue($"{value?.ToString(CultureInfo.InvariantCulture)}"); - await ValidateInsertSelectDeleteSingleValueAsync( - table.TableName, - columnName, - value, - formattedValue); - } - - /// - /// Validates if driver can receive specific NULL values correctly. - /// - [SkippableTheory] - [InlineData("NULL")] - [InlineData("CAST(NULL AS INT)")] - [InlineData("CAST(NULL AS BIGINT)")] - [InlineData("CAST(NULL AS SMALLINT)")] - [InlineData("CAST(NULL AS TINYINT)")] - [InlineData("CAST(NULL AS FLOAT)")] - [InlineData("CAST(NULL AS DOUBLE)")] - [InlineData("CAST(NULL AS DECIMAL(38,0))")] - [InlineData("CAST(NULL AS STRING)")] - [InlineData("CAST(NULL AS VARCHAR(10))")] - [InlineData("CAST(NULL AS CHAR(10))")] - [InlineData("CAST(NULL AS BOOLEAN)")] - // TODO: Returns byte[] [] (i.e., empty array) - expecting null value. - //[InlineData("CAST(NULL AS BINARY)", Skip = "Returns empty array - expecting null value.")] - public async Task TestNullData(string projectionClause) + public BinaryBooleanValueTests(ITestOutputHelper output) + : base(output, new SparkTestEnvironment.Factory()) { - string selectStatement = $"SELECT {projectionClause};"; - // Note: by default, this returns as String type, not NULL type. 
- await SelectAndValidateValuesAsync(selectStatement, null, 1); } } } diff --git a/csharp/test/Drivers/Apache/Spark/ClientTests.cs b/csharp/test/Drivers/Apache/Spark/ClientTests.cs new file mode 100644 index 0000000000..d5230749ae --- /dev/null +++ b/csharp/test/Drivers/Apache/Spark/ClientTests.cs @@ -0,0 +1,56 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark +{ + public class ClientTests : Common.ClientTests + { + public ClientTests(ITestOutputHelper? outputHelper) + : base(outputHelper, new SparkTestEnvironment.Factory()) + { + } + + protected override IReadOnlyList GetUpdateExpectedResults() + { + int affectedRows = ValidateAffectedRows ? 1 : -1; + return GetUpdateExpecteResults(affectedRows, TestEnvironment.ServerType == SparkServerType.Databricks); + } + + internal static IReadOnlyList GetUpdateExpecteResults(int affectedRows, bool isDatabricks) + { + return !isDatabricks + ? [ + -1, // CREATE TABLE + affectedRows, // INSERT + affectedRows, // INSERT + affectedRows, // INSERT + ] + : [ + -1, // CREATE TABLE + affectedRows, // INSERT + affectedRows, // INSERT + affectedRows, // INSERT + affectedRows, // UPDATE + affectedRows, // DELETE + ]; + } + } +} diff --git a/csharp/test/Drivers/Apache/Spark/ComplexTypesValueTests.cs b/csharp/test/Drivers/Apache/Spark/ComplexTypesValueTests.cs index 2968c68f51..c2ca23b60a 100644 --- a/csharp/test/Drivers/Apache/Spark/ComplexTypesValueTests.cs +++ b/csharp/test/Drivers/Apache/Spark/ComplexTypesValueTests.cs @@ -15,65 +15,15 @@ * limitations under the License. */ -using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using Xunit; using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { - // TODO: When supported, use prepared statements instead of SQL string literals - // Which will better test how the driver handles values sent/received - - /// - /// Validates that specific complex structured types can be inserted, retrieved and targeted correctly - /// - public class ComplexTypesValueTests : SparkTestBase + public class ComplexTypesValueTests : Common.ComplexTypesValueTests { - public ComplexTypesValueTests(ITestOutputHelper output) : base(output) { } - - /// - /// Validates if driver can send and receive specific array of integer values correctly. 
- /// - [SkippableTheory] - [InlineData("ARRAY(CAST(1 AS INT), 2, 3)", "[1,2,3]")] - [InlineData("ARRAY(CAST(1 AS LONG), 2, 3)", "[1,2,3]")] - [InlineData("ARRAY(CAST(1 AS DOUBLE), 2, 3)", "[1.0,2.0,3.0]")] - [InlineData("ARRAY(CAST(1 AS NUMERIC(38,0)), 2, 3)", "[1,2,3]")] - [InlineData("ARRAY(CAST('John Doe' AS STRING), 2, 3)", """["John Doe","2","3"]""")] - // Note: Timestamp returned adjusted to UTC. - [InlineData("ARRAY(CAST('2024-01-01T00:00:00-07:00' AS TIMESTAMP_LTZ), CAST('2024-02-02T02:02:02+01:30' AS TIMESTAMP_LTZ), CAST('2024-03-03T03:03:03Z' AS TIMESTAMP_LTZ))", """[2024-01-01 07:00:00,2024-02-02 00:32:02,2024-03-03 03:03:03]""")] - [InlineData("ARRAY(CAST('2024-01-01T00:00:00Z' AS DATE), CAST('2024-02-02T02:02:02Z' AS DATE), CAST('2024-03-03T03:03:03Z' AS DATE))", """[2024-01-01,2024-02-02,2024-03-03]""")] - [InlineData("ARRAY(INTERVAL 123 YEARS 11 MONTHS, INTERVAL 5 YEARS, INTERVAL 6 MONTHS)", """[123-11,5-0,0-6]""")] - public async Task TestArrayData(string projection, string value) - { - string selectStatement = $"SELECT {projection};"; - await SelectAndValidateValuesAsync(selectStatement, value, 1); - } - - /// - /// Validates if driver can send and receive specific map values correctly. - /// - [SkippableTheory] - [InlineData("MAP(1, 'John Doe', 2, 'Jane Doe', 3, 'Jack Doe')", """{1:"John Doe",2:"Jane Doe",3:"Jack Doe"}""")] - [InlineData("MAP('John Doe', 1, 'Jane Doe', 2, 'Jack Doe', 3)", """{"Jack Doe":3,"Jane Doe":2,"John Doe":1}""")] - public async Task TestMapData(string projection, string value) - { - string selectStatement = $"SELECT {projection};"; - await SelectAndValidateValuesAsync(selectStatement, value, 1); - } - - /// - /// Validates if driver can send and receive specific map values correctly. - /// - [SkippableTheory] - [InlineData("STRUCT(CAST(1 AS INT), CAST('John Doe' AS STRING))", """{"col1":1,"col2":"John Doe"}""")] - [InlineData("STRUCT(CAST('John Doe' AS STRING), CAST(1 AS INT))", """{"col1":"John Doe","col2":1}""")] - public async Task TestStructData(string projection, string value) + public ComplexTypesValueTests(ITestOutputHelper output) + : base(output, new SparkTestEnvironment.Factory()) { - string selectStatement = $"SELECT {projection};"; - await SelectAndValidateValuesAsync(selectStatement, value, 1); } } } diff --git a/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs b/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs index c2278f3886..e9d832ab9e 100644 --- a/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs +++ b/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs @@ -16,58 +16,26 @@ */ using System; -using System.Collections.Generic; using System.Globalization; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Xunit; using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { - // TODO: When supported, use prepared statements instead of SQL string literals - // Which will better test how the driver handles values sent/received - - /// - /// Validates that specific date, timestamp and interval values can be inserted, retrieved and targeted correctly - /// - public class DateTimeValueTests : SparkTestBase + public class DateTimeValueTests : Common.DateTimeValueTests { - // Spark handles microseconds but not nanoseconds. Truncated to 6 decimal places. 
- const string DateTimeZoneFormat = "yyyy-MM-dd'T'HH:mm:ss'.'ffffffK"; - const string DateFormat = "yyyy-MM-dd"; - - private static readonly DateTimeOffset[] s_timestampValues = new[] - { -#if NET5_0_OR_GREATER - DateTimeOffset.UnixEpoch, -#endif - DateTimeOffset.MinValue, - DateTimeOffset.MaxValue, - DateTimeOffset.UtcNow, - DateTimeOffset.UtcNow.ToOffset(TimeSpan.FromHours(4)) - }; + public DateTimeValueTests(ITestOutputHelper output) + : base(output, new SparkTestEnvironment.Factory()) + { } - public DateTimeValueTests(ITestOutputHelper output) : base(output) { } - - /// - /// Validates if driver can send and receive specific Timstamp values correctly - /// [SkippableTheory] - [MemberData(nameof(TimestampData), "TIMESTAMP")] [MemberData(nameof(TimestampData), "TIMESTAMP_LTZ")] - public async Task TestTimestampData(DateTimeOffset value, string columnType) + public async Task TestTimestampDataDatabricks(DateTimeOffset value, string columnType) { - string columnName = "TIMESTAMPTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, columnType)); - - string formattedValue = $"{value.ToString(DateTimeZoneFormat, CultureInfo.InvariantCulture)}"; - DateTimeOffset truncatedValue = DateTimeOffset.ParseExact(formattedValue, DateTimeZoneFormat, CultureInfo.InvariantCulture); - - await ValidateInsertSelectDeleteSingleValueAsync( - table.TableName, - columnName, - truncatedValue, - QuoteValue(formattedValue)); + Skip.If(TestEnvironment.ServerType != SparkServerType.Databricks); + await base.TestTimestampData(value, columnType); } /// @@ -75,11 +43,9 @@ await ValidateInsertSelectDeleteSingleValueAsync( /// [SkippableTheory] [MemberData(nameof(TimestampData), "TIMESTAMP_NTZ")] - public async Task TestTimestampNoTimezoneData(DateTimeOffset value, string columnType) + public async Task TestTimestampNoTimezoneDataDatabricks(DateTimeOffset value, string columnType) { - // Note: Minimum value falls outside range of valid values on server when no time zone is included. Cannot be selected - Skip.If(value == DateTimeOffset.MinValue); - + Skip.If(TestEnvironment.ServerType != SparkServerType.Databricks); string columnName = "TIMESTAMPTYPE"; using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, columnType)); @@ -93,70 +59,5 @@ await ValidateInsertSelectDeleteSingleValueAsync( new DateTimeOffset(truncatedValue.DateTime, TimeSpan.Zero), QuoteValue(formattedValue)); } - - /// - /// Validates if driver can send and receive specific no timezone Timstamp values correctly - /// - [SkippableTheory] - [MemberData(nameof(TimestampData), "DATE")] - public async Task TestDateData(DateTimeOffset value, string columnType) - { - string columnName = "DATETYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, columnType)); - - string formattedValue = $"{value.ToString(DateFormat, CultureInfo.InvariantCulture)}"; - DateTimeOffset truncatedValue = DateTimeOffset.ParseExact(formattedValue, DateFormat, CultureInfo.InvariantCulture); - - await ValidateInsertSelectDeleteSingleValueAsync( - table.TableName, - columnName, - // Remove timezone offset - new DateTimeOffset(truncatedValue.DateTime, TimeSpan.Zero), - QuoteValue(formattedValue)); - } - - /// - /// Tests INTERVAL data types (YEAR-MONTH and DAY-SECOND). - /// - /// The INTERVAL to test. - /// The expected return value. 
- /// - [SkippableTheory] - [InlineData("INTERVAL 1 YEAR", "1-0")] - [InlineData("INTERVAL 1 YEAR 2 MONTH", "1-2")] - [InlineData("INTERVAL 2 MONTHS", "0-2")] - [InlineData("INTERVAL -1 YEAR", "-1-0")] - [InlineData("INTERVAL -1 YEAR 2 MONTH", "-0-10")] - [InlineData("INTERVAL -2 YEAR 2 MONTH", "-1-10")] - [InlineData("INTERVAL 1 YEAR -2 MONTH", "0-10")] - [InlineData("INTERVAL 178956970 YEAR", "178956970-0")] - [InlineData("INTERVAL 178956969 YEAR 11 MONTH", "178956969-11")] - [InlineData("INTERVAL -178956970 YEAR", "-178956970-0")] - [InlineData("INTERVAL 0 DAYS 0 HOURS 0 MINUTES 0 SECONDS", "0 00:00:00.000000000")] - [InlineData("INTERVAL 1 DAYS", "1 00:00:00.000000000")] - [InlineData("INTERVAL 2 HOURS", "0 02:00:00.000000000")] - [InlineData("INTERVAL 3 MINUTES", "0 00:03:00.000000000")] - [InlineData("INTERVAL 4 SECONDS", "0 00:00:04.000000000")] - [InlineData("INTERVAL 1 DAYS 2 HOURS", "1 02:00:00.000000000")] - [InlineData("INTERVAL 1 DAYS 2 HOURS 3 MINUTES", "1 02:03:00.000000000")] - [InlineData("INTERVAL 1 DAYS 2 HOURS 3 MINUTES 4 SECONDS", "1 02:03:04.000000000")] - [InlineData("INTERVAL 1 DAYS 2 HOURS 3 MINUTES 4.123123123 SECONDS", "1 02:03:04.123123000")] // Only to microseconds - [InlineData("INTERVAL 106751990 DAYS 23 HOURS 59 MINUTES 59.999999 SECONDS", "106751990 23:59:59.999999000")] - [InlineData("INTERVAL 106751991 DAYS 0 HOURS 0 MINUTES 0 SECONDS", "106751991 00:00:00.000000000")] - [InlineData("INTERVAL -106751991 DAYS 0 HOURS 0 MINUTES 0 SECONDS", "-106751991 00:00:00.000000000")] - [InlineData("INTERVAL -106751991 DAYS 23 HOURS 59 MINUTES 59.999999 SECONDS", "-106751990 00:00:00.000001000")] - public async Task TestIntervalData(string intervalClause, string value) - { - string selectStatement = $"SELECT {intervalClause} AS INTERVAL_VALUE;"; - await SelectAndValidateValuesAsync(selectStatement, value, 1); - } - - public static IEnumerable TimestampData(string columnType) - { - foreach (DateTimeOffset timestamp in s_timestampValues) - { - yield return new object[] { timestamp, columnType }; - } - } } } diff --git a/csharp/test/Drivers/Apache/Spark/DriverTests.cs b/csharp/test/Drivers/Apache/Spark/DriverTests.cs index 8fb341f872..6970777c22 100644 --- a/csharp/test/Drivers/Apache/Spark/DriverTests.cs +++ b/csharp/test/Drivers/Apache/Spark/DriverTests.cs @@ -15,535 +15,47 @@ * limitations under the License. */ -using System; using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Apache.Arrow.Adbc.Tests.Metadata; -using Apache.Arrow.Adbc.Tests.Xunit; -using Apache.Arrow.Ipc; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Xunit; using Xunit.Abstractions; -using ColumnTypeId = Apache.Arrow.Adbc.Drivers.Apache.Spark.SparkConnection.ColumnTypeId; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { - /// - /// Class for testing the Snowflake ADBC driver connection tests. - /// - /// - /// Tests are ordered to ensure data is created for the other - /// queries to run. - /// - [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] - public class DriverTests : SparkTestBase + public class DriverTests : Common.DriverTests { - /// - /// Supported Spark data types as a subset of - /// - private enum SupportedSparkDataType : short + public DriverTests(ITestOutputHelper? 
outputHelper) + : base(outputHelper, new SparkTestEnvironment.Factory()) { - ARRAY = ColumnTypeId.ARRAY, - BIGINT = ColumnTypeId.BIGINT, - BINARY = ColumnTypeId.BINARY, - BOOLEAN = ColumnTypeId.BOOLEAN, - CHAR = ColumnTypeId.CHAR, - DATE = ColumnTypeId.DATE, - DECIMAL = ColumnTypeId.DECIMAL, - DOUBLE = ColumnTypeId.DOUBLE, - FLOAT = ColumnTypeId.FLOAT, - INTEGER = ColumnTypeId.INTEGER, - JAVA_OBJECT = ColumnTypeId.JAVA_OBJECT, - LONGNVARCHAR = ColumnTypeId.LONGNVARCHAR, - LONGVARBINARY = ColumnTypeId.LONGVARBINARY, - LONGVARCHAR = ColumnTypeId.LONGVARCHAR, - NCHAR = ColumnTypeId.NCHAR, - NULL = ColumnTypeId.NULL, - NUMERIC = ColumnTypeId.NUMERIC, - NVARCHAR = ColumnTypeId.NVARCHAR, - REAL = ColumnTypeId.REAL, - SMALLINT = ColumnTypeId.SMALLINT, - STRUCT = ColumnTypeId.STRUCT, - TIMESTAMP = ColumnTypeId.TIMESTAMP, - TINYINT = ColumnTypeId.TINYINT, - VARBINARY = ColumnTypeId.VARBINARY, - VARCHAR = ColumnTypeId.VARCHAR, } - private static List DefaultTableTypes => new() { "TABLE", "VIEW" }; - - public DriverTests(ITestOutputHelper? outputHelper) : base(outputHelper) - { - Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); - } - - /// - /// Validates if the driver can execute update statements. - /// - [SkippableFact, Order(1)] - public void CanExecuteUpdate() - { - AdbcConnection adbcConnection = NewConnection(); - - string[] queries = GetQueries(); - - List expectedResults = new() { - -1, // DROP TABLE - -1, // CREATE TABLE - 1, // INSERT - 1, // INSERT - 1, // INSERT - 1, // UPDATE - 1, // DELETE - }; - - for (int i = 0; i < queries.Length; i++) - { - string query = queries[i]; - AdbcStatement statement = adbcConnection.CreateStatement(); - statement.SqlQuery = query; - - UpdateResult updateResult = statement.ExecuteUpdate(); - - Assert.Equal(expectedResults[i], updateResult.AffectedRows); - } - } - - /// - /// Validates if the driver can call GetInfo. - /// - [SkippableFact, Order(2)] - public async Task CanGetInfo() - { - AdbcConnection adbcConnection = NewConnection(); - - // Test the supported info codes - List handledCodes = new List() - { - AdbcInfoCode.DriverName, - AdbcInfoCode.DriverVersion, - AdbcInfoCode.VendorName, - AdbcInfoCode.DriverArrowVersion, - AdbcInfoCode.VendorVersion, - AdbcInfoCode.VendorSql - }; - using IArrowArrayStream stream = adbcConnection.GetInfo(handledCodes); - - RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync(); - UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name"); - - List expectedValues = new List() - { - "DriverName", - "DriverVersion", - "VendorName", - "DriverArrowVersion", - "VendorVersion", - "VendorSql" - }; - - for (int i = 0; i < infoNameArray.Length; i++) - { - AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i); - DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); - - Assert.Contains(value.ToString(), expectedValues); - - switch (value) - { - case AdbcInfoCode.VendorSql: - // TODO: How does external developer know the second field is the boolean field? - BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1]; - bool? boolValue = booleanArray.GetValue(i); - OutputHelper?.WriteLine($"{value}={boolValue}"); - Assert.True(boolValue); - break; - default: - StringArray stringArray = (StringArray)valueArray.Fields[0]; - string stringValue = stringArray.GetString(i); - OutputHelper?.WriteLine($"{value}={stringValue}"); - Assert.NotNull(stringValue); - break; - } - } - - // Test the unhandled info codes. 
- List unhandledCodes = new List() - { - AdbcInfoCode.VendorArrowVersion, - AdbcInfoCode.VendorSubstrait, - AdbcInfoCode.VendorSubstraitMaxVersion - }; - using IArrowArrayStream stream2 = adbcConnection.GetInfo(unhandledCodes); - - recordBatch = await stream2.ReadNextRecordBatchAsync(); - infoNameArray = (UInt32Array)recordBatch.Column("info_name"); - - List unexpectedValues = new List() - { - "VendorArrowVersion", - "VendorSubstrait", - "VendorSubstraitMaxVersion" - }; - for (int i = 0; i < infoNameArray.Length; i++) - { - AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i); - DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); - - Assert.Contains(value.ToString(), unexpectedValues); - switch (value) - { - case AdbcInfoCode.VendorSql: - BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1]; - Assert.Null(booleanArray.GetValue(i)); - break; - default: - StringArray stringArray = (StringArray)valueArray.Fields[0]; - Assert.Null(stringArray.GetString(i)); - break; - } - } - } - - /// - /// Validates if the driver can call GetObjects with GetObjectsDepth as Catalogs with CatalogPattern as a pattern. - /// - /// - [SkippableTheory, Order(3)] + [SkippableTheory] [MemberData(nameof(CatalogNamePatternData))] - public void GetGetObjectsCatalogs(string pattern) + public override void CanGetObjectsCatalogs(string? pattern) { - string? catalogName = TestConfiguration.Metadata.Catalog; - string? schemaName = TestConfiguration.Metadata.Schema; - - using IArrowArrayStream stream = Connection.GetObjects( - depth: AdbcConnection.GetObjectsDepth.Catalogs, - catalogPattern: pattern, - dbSchemaPattern: null, - tableNamePattern: null, - tableTypes: DefaultTableTypes, - columnNamePattern: null); - - using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, catalogName, null); - AdbcCatalog? catalog = catalogs.Where((catalog) => string.Equals(catalog.Name, catalogName)).FirstOrDefault(); - - Assert.True(catalog != null, "catalog should not be null"); + GetObjectsCatalogsTest(pattern); } - /// - /// Validates if the driver can call GetObjects with GetObjectsDepth as DbSchemas with DbSchemaName as a pattern. - /// - [SkippableTheory, Order(4)] + [SkippableTheory] [MemberData(nameof(DbSchemasNamePatternData))] - public void CanGetObjectsDbSchemas(string dbSchemaPattern) + public override void CanGetObjectsDbSchemas(string dbSchemaPattern) { - // need to add the database - string? databaseName = TestConfiguration.Metadata.Catalog; - string? schemaName = TestConfiguration.Metadata.Schema; - - using IArrowArrayStream stream = Connection.GetObjects( - depth: AdbcConnection.GetObjectsDepth.DbSchemas, - catalogPattern: databaseName, - dbSchemaPattern: dbSchemaPattern, - tableNamePattern: null, - tableTypes: DefaultTableTypes, - columnNamePattern: null); - - using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); - - List? dbSchemas = catalogs - .Where(c => string.Equals(c.Name, databaseName)) - .Select(c => c.DbSchemas) - .FirstOrDefault(); - AdbcDbSchema? dbSchema = dbSchemas?.Where((dbSchema) => string.Equals(dbSchema.Name, schemaName)).FirstOrDefault(); - - Assert.True(dbSchema != null, "dbSchema should not be null"); + GetObjectsDbSchemasTest(dbSchemaPattern); } - /// - /// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a pattern. 
- /// - [SkippableTheory, Order(5)] + [SkippableTheory] [MemberData(nameof(TableNamePatternData))] - public void CanGetObjectsTables(string tableNamePattern) - { - // need to add the database - string? databaseName = TestConfiguration.Metadata.Catalog; - string? schemaName = TestConfiguration.Metadata.Schema; - string? tableName = TestConfiguration.Metadata.Table; - - using IArrowArrayStream stream = Connection.GetObjects( - depth: AdbcConnection.GetObjectsDepth.Tables, - catalogPattern: databaseName, - dbSchemaPattern: schemaName, - tableNamePattern: tableNamePattern, - tableTypes: DefaultTableTypes, - columnNamePattern: null); - - using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); - - List? tables = catalogs - .Where(c => string.Equals(c.Name, databaseName)) - .Select(c => c.DbSchemas) - .FirstOrDefault() - ?.Where(s => string.Equals(s.Name, schemaName)) - .Select(s => s.Tables) - .FirstOrDefault(); - - AdbcTable? table = tables?.Where((table) => string.Equals(table.Name, tableName)).FirstOrDefault(); - Assert.True(table != null, "table should not be null"); - // TODO: Determine why this is returned blank. - //Assert.Equal("TABLE", table.Type); - } - - /// - /// Validates if the driver can call GetObjects for GetObjectsDepth as All. - /// - [SkippableFact, Order(6)] - public void CanGetObjectsAll() - { - // need to add the database - string? databaseName = TestConfiguration.Metadata.Catalog; - string? schemaName = TestConfiguration.Metadata.Schema; - string? tableName = TestConfiguration.Metadata.Table; - string? columnName = null; - - using IArrowArrayStream stream = Connection.GetObjects( - depth: AdbcConnection.GetObjectsDepth.All, - catalogPattern: databaseName, - dbSchemaPattern: schemaName, - tableNamePattern: tableName, - tableTypes: DefaultTableTypes, - columnNamePattern: columnName); - - using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); - AdbcTable? table = catalogs - .Where(c => string.Equals(c.Name, databaseName)) - .Select(c => c.DbSchemas) - .FirstOrDefault() - ?.Where(s => string.Equals(s.Name, schemaName)) - .Select(s => s.Tables) - .FirstOrDefault() - ?.Where(t => string.Equals(t.Name, tableName)) - .FirstOrDefault(); - - Assert.True(table != null, "table should not be null"); - // TODO: Determine why this is returned blank. - //Assert.Equal("TABLE", table.Type); - List? columns = table.Columns; - - Assert.True(columns != null, "Columns cannot be null"); - Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, columns.Count); - - for (int i = 0; i < columns.Count; i++) - { - // Verify column metadata is returned/consistent. 
- AdbcColumn column = columns[i]; - Assert.Equal(i + 1, column.OrdinalPosition); - Assert.False(string.IsNullOrEmpty(column.Name)); - Assert.False(string.IsNullOrEmpty(column.XdbcTypeName)); - - var supportedTypes = Enum.GetValues(typeof(SupportedSparkDataType)).Cast(); - Assert.Contains((SupportedSparkDataType)column.XdbcSqlDataType!, supportedTypes); - Assert.Equal(column.XdbcDataType, column.XdbcSqlDataType); - - Assert.NotNull(column.XdbcDataType); - Assert.Contains((SupportedSparkDataType)column.XdbcDataType!, supportedTypes); - - HashSet typesHaveColumnSize = new() - { - (short)SupportedSparkDataType.DECIMAL, - (short)SupportedSparkDataType.NUMERIC, - (short)SupportedSparkDataType.CHAR, - (short)SupportedSparkDataType.VARCHAR, - }; - HashSet typesHaveDecimalDigits = new() - { - (short)SupportedSparkDataType.DECIMAL, - (short)SupportedSparkDataType.NUMERIC, - }; - - bool typeHasColumnSize = typesHaveColumnSize.Contains(column.XdbcDataType.Value); - Assert.Equal(column.XdbcColumnSize.HasValue, typeHasColumnSize); - - bool typeHasDecimalDigits = typesHaveDecimalDigits.Contains(column.XdbcDataType.Value); - Assert.Equal(column.XdbcDecimalDigits.HasValue, typeHasDecimalDigits); - - Assert.False(string.IsNullOrEmpty(column.Remarks)); - - Assert.NotNull(column.XdbcColumnDef); - - Assert.NotNull(column.XdbcNullable); - Assert.Contains(new short[] { 1, 0 }, i => i == column.XdbcNullable); - - Assert.NotNull(column.XdbcIsNullable); - Assert.Contains(new string[] { "YES", "NO" }, i => i.Equals(column.XdbcIsNullable)); - - Assert.NotNull(column.XdbcIsAutoIncrement); - - Assert.Null(column.XdbcCharOctetLength); - Assert.Null(column.XdbcDatetimeSub); - Assert.Null(column.XdbcNumPrecRadix); - Assert.Null(column.XdbcScopeCatalog); - Assert.Null(column.XdbcScopeSchema); - Assert.Null(column.XdbcScopeTable); - } - } - - /// - /// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a Special Character. - /// - [SkippableTheory, Order(7)] - [InlineData("MyIdentifier")] - [InlineData("ONE")] - [InlineData("mYiDentifier")] - [InlineData("3rd_identifier")] - // Note: Tables in 'hive_metastore' only support ASCII alphabetic, numeric and underscore. - public void CanGetObjectsTablesWithSpecialCharacter(string tableName) + public override void CanGetObjectsTables(string tableNamePattern) { - string catalogName = TestConfiguration.Metadata.Catalog; - string schemaPrefix = Guid.NewGuid().ToString().Replace("-", ""); - using TemporarySchema schema = TemporarySchema.NewTemporarySchemaAsync(catalogName, Statement).Result; - string schemaName = schema.SchemaName; - string fullTableName = $"{DelimitIdentifier(catalogName)}.{DelimitIdentifier(schemaName)}.{DelimitIdentifier(tableName)}"; - using TemporaryTable temporaryTable = TemporaryTable.NewTemporaryTableAsync(Statement, fullTableName, $"CREATE TABLE IF NOT EXISTS {fullTableName} (INDEX INT)").Result; - - using IArrowArrayStream stream = Connection.GetObjects( - depth: AdbcConnection.GetObjectsDepth.Tables, - catalogPattern: catalogName, - dbSchemaPattern: schemaName, - tableNamePattern: tableName, - tableTypes: DefaultTableTypes, - columnNamePattern: null); - - using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, catalogName, schemaName); - - List? 
tables = catalogs - .Where(c => string.Equals(c.Name, catalogName)) - .Select(c => c.DbSchemas) - .FirstOrDefault() - ?.Where(s => string.Equals(s.Name, schemaName)) - .Select(s => s.Tables) - .FirstOrDefault(); - - AdbcTable? table = tables?.FirstOrDefault(); - - Assert.True(table != null, "table should not be null"); - Assert.Equal(tableName, table.Name, true); + GetObjectsTablesTest(tableNamePattern); } - /// - /// Validates if the driver can call GetTableSchema. - /// - [SkippableFact, Order(8)] - public void CanGetTableSchema() + protected override IReadOnlyList GetUpdateExpectedResults() { - AdbcConnection adbcConnection = NewConnection(); - - string? catalogName = TestConfiguration.Metadata.Catalog; - string? schemaName = TestConfiguration.Metadata.Schema; - string tableName = TestConfiguration.Metadata.Table!; - - Schema schema = adbcConnection.GetTableSchema(catalogName, schemaName, tableName); - - int numberOfFields = schema.FieldsList.Count; - - Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, numberOfFields); + int affectedRows = ValidateAffectedRows ? 1 : -1; + return ClientTests.GetUpdateExpecteResults(affectedRows, TestEnvironment.ServerType == SparkServerType.Databricks); } - /// - /// Validates if the driver can call GetTableTypes. - /// - [SkippableFact, Order(9)] - public async Task CanGetTableTypes() - { - AdbcConnection adbcConnection = NewConnection(); - - using IArrowArrayStream arrowArrayStream = adbcConnection.GetTableTypes(); - - RecordBatch recordBatch = await arrowArrayStream.ReadNextRecordBatchAsync(); - - StringArray stringArray = (StringArray)recordBatch.Column("table_type"); - - List known_types = new List - { - "TABLE", "VIEW" - }; - - int results = 0; - - for (int i = 0; i < stringArray.Length; i++) - { - string value = stringArray.GetString(i); - - if (known_types.Contains(value)) - { - results++; - } - } - - Assert.Equal(known_types.Count, results); - } - - /// - /// Validates if the driver can connect to a live server and - /// parse the results. - /// - [SkippableFact, Order(10)] - public void CanExecuteQuery() - { - using AdbcConnection adbcConnection = NewConnection(); - - using AdbcStatement statement = adbcConnection.CreateStatement(); - statement.SqlQuery = TestConfiguration.Query; - - QueryResult queryResult = statement.ExecuteQuery(); - - Tests.DriverTests.CanExecuteQuery(queryResult, TestConfiguration.ExpectedResultsCount); - } - - /// - /// Validates if the driver can connect to a live server and - /// parse the results using the asynchronous methods. - /// - [SkippableFact, Order(11)] - public async Task CanExecuteQueryAsync() - { - using AdbcConnection adbcConnection = NewConnection(); - using AdbcStatement statement = adbcConnection.CreateStatement(); - - statement.SqlQuery = TestConfiguration.Query; - QueryResult queryResult = await statement.ExecuteQueryAsync(); - - await Tests.DriverTests.CanExecuteQueryAsync(queryResult, TestConfiguration.ExpectedResultsCount); - } - - /// - /// Validates if the driver can connect to a live server and - /// perform and update asynchronously. 
- /// - [SkippableFact, Order(12)] - public async Task CanExecuteUpdateAsync() - { - using AdbcConnection adbcConnection = NewConnection(); - using AdbcStatement statement = adbcConnection.CreateStatement(); - using TemporaryTable temporaryTable = await NewTemporaryTableAsync(statement, "INDEX INT"); - - statement.SqlQuery = GetInsertValueStatement(temporaryTable.TableName, "INDEX", "1"); - UpdateResult updateResult = await statement.ExecuteUpdateAsync(); - - Assert.Equal(1, updateResult.AffectedRows); - } public static IEnumerable CatalogNamePatternData() { diff --git a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs index f84371e1d2..417cfc3c64 100644 --- a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs +++ b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs @@ -15,249 +15,15 @@ * limitations under the License. */ -using System.Data.SqlTypes; -using System.Threading.Tasks; -using Apache.Arrow.Adbc.Drivers.Apache.Hive2; -using Xunit; using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { - // TODO: When supported, use prepared statements instead of SQL string literals - // Which will better test how the driver handles values sent/received - - public class NumericValueTests : SparkTestBase + public class NumericValueTests : Common.NumericValueTests { - /// - /// Validates that specific numeric values can be inserted, retrieved and targeted correctly - /// - public NumericValueTests(ITestOutputHelper output) : base(output) { } - - /// - /// Validates if driver can send and receive specific Integer values correctly - /// - [SkippableTheory] - [InlineData(-1)] - [InlineData(0)] - [InlineData(1)] - public async Task TestIntegerSanity(int value) - { - string columnName = "INTTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} INT", columnName)); - await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); - } - - /// - /// Validates if driver can handle the largest / smallest numbers - /// - [SkippableTheory] - [InlineData(int.MaxValue)] - [InlineData(int.MinValue)] - public async Task TestIntegerMinMax(int value) - { - string columnName = "INTTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} INT", columnName)); - await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); - } - - /// - /// Validates if driver can handle the largest / smallest numbers - /// - [SkippableTheory] - [InlineData(long.MaxValue)] - [InlineData(long.MinValue)] - public async Task TestLongMinMax(long value) - { - string columnName = "BIGINTTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} BIGINT", columnName)); - await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); - } - - /// - /// Validates if driver can handle the largest / smallest numbers - /// - [SkippableTheory] - [InlineData(short.MaxValue)] - [InlineData(short.MinValue)] - public async Task TestSmallIntMinMax(short value) - { - string columnName = "SMALLINTTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} SMALLINT", columnName)); - await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); - } - - /// - /// Validates if driver can handle the largest / smallest numbers - /// - [SkippableTheory] - [InlineData(sbyte.MaxValue)] - [InlineData(sbyte.MinValue)] - public async 
Task TestTinyIntMinMax(sbyte value) - { - string columnName = "TINYINTTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} TINYINT", columnName)); - await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, value); - } - - /// - /// Validates if driver can handle smaller Number type correctly - /// - [SkippableTheory] - [InlineData("-1")] - [InlineData("0")] - [InlineData("1")] - [InlineData("99")] - [InlineData("-99")] - public async Task TestSmallNumberRange(string value) - { - string columnName = "SMALLNUMBER"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(2,0)", columnName)); - await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, SqlDecimal.Parse(value)); - } - - /// - /// Validates if driver correctly errors out when the values exceed the column's limit - /// - [SkippableTheory] - [InlineData(-100)] - [InlineData(100)] - [InlineData(int.MaxValue)] - [InlineData(int.MinValue)] - public async Task TestSmallNumberRangeOverlimit(int value) - { - string columnName = "SMALLNUMBER"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(2,0)", columnName)); - await Assert.ThrowsAsync(async () => await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, new SqlDecimal(value))); - } - - /// - /// Validates if driver can handle a large scale Number type correctly - /// - [SkippableTheory] - [InlineData("0")] - [InlineData("-2.003")] - [InlineData("4.85")] - [InlineData("0.0000000000000000000000000000000000001")] - [InlineData("9.5545204502636499875576383003668916798")] - public async Task TestLargeScaleNumberRange(string value) - { - string columnName = "LARGESCALENUMBER"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,37)", columnName)); - await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, SqlDecimal.Parse(value)); - } - - /// - /// Validates if driver can error handle when input goes beyond a large scale Number type - /// - [SkippableTheory] - [InlineData("-10")] - [InlineData("10")] - [InlineData("99999999999999999999999999999999999999")] - [InlineData("-99999999999999999999999999999999999999")] - public async Task TestLargeScaleNumberOverlimit(string value) - { - string columnName = "LARGESCALENUMBER"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,37)", columnName)); - await Assert.ThrowsAsync(async () => await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, SqlDecimal.Parse(value))); - } - - /// - /// Validates if driver can handle a small scale Number type correctly - /// - [SkippableTheory] - [InlineData("0")] - [InlineData("4.85")] - [InlineData("-999999999999999999999999999999999999.99")] - [InlineData("999999999999999999999999999999999999.99")] - public async Task TestSmallScaleNumberRange(string value) - { - string columnName = "SMALLSCALENUMBER"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)", columnName)); - await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, SqlDecimal.Parse(value)); - } - - /// - /// Validates if driver can error handle when an insert goes beyond a small scale Number type correctly - /// - [SkippableTheory] - [InlineData("-99999999999999999999999999999999999999")] - [InlineData("99999999999999999999999999999999999999")] - 
public async Task TestSmallScaleNumberOverlimit(string value) - { - string columnName = "SMALLSCALENUMBER"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)", columnName)); - await Assert.ThrowsAsync(async () => await ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, SqlDecimal.Parse(value))); - } - - - /// - /// Tests that decimals are rounded as expected. - /// Snowflake allows inserts of scales beyond the data type size, but storage of value will round it up or down - /// - [SkippableTheory] - [InlineData(2.467, 2.47)] - [InlineData(-672.613, -672.61)] - public async Task TestRoundingNumbers(decimal input, decimal output) - { - string columnName = "SMALLSCALENUMBER"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)", columnName)); - SqlDecimal value = new SqlDecimal(input); - SqlDecimal returned = new SqlDecimal(output); - await InsertSingleValueAsync(table.TableName, columnName, value.ToString()); - await SelectAndValidateValuesAsync(table.TableName, columnName, returned, 1); - string whereClause = GetWhereClause(columnName, returned); - await DeleteFromTableAsync(table.TableName, whereClause, 1); - } - - /// - /// Validates if driver can handle floating point number type correctly - /// - [SkippableTheory] - [InlineData(0)] - [InlineData(0.2)] - [InlineData(15e-03)] - [InlineData(1.234E+2)] - [InlineData(double.NegativeInfinity)] - [InlineData(double.PositiveInfinity)] - [InlineData(double.NaN)] - [InlineData(double.MinValue)] - [InlineData(double.MaxValue)] - public async Task TestDoubleValuesInsertSelectDelete(double value) - { - string columnName = "DOUBLETYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} DOUBLE", columnName)); - string valueString = ConvertDoubleToString(value); - await InsertSingleValueAsync(table.TableName, columnName, valueString); - await SelectAndValidateValuesAsync(table.TableName, columnName, value, 1); - string whereClause = GetWhereClause(columnName, value); - await DeleteFromTableAsync(table.TableName, whereClause, 1); - } - - /// - /// Validates if driver can handle floating point number type correctly - /// - [SkippableTheory] - [InlineData(0)] - [InlineData(25)] - [InlineData(float.NegativeInfinity)] - [InlineData(float.PositiveInfinity)] - [InlineData(float.NaN)] - // TODO: Solve server issue when non-integer float value is used in where clause. 
- //[InlineData(25.1)] - //[InlineData(0.2)] - //[InlineData(15e-03)] - //[InlineData(1.234E+2)] - //[InlineData(float.MinValue)] - //[InlineData(float.MaxValue)] - public async Task TestFloatValuesInsertSelectDelete(float value) + public NumericValueTests(ITestOutputHelper output) + : base(output, new SparkTestEnvironment.Factory()) { - string columnName = "FLOATTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} FLOAT", columnName)); - string valueString = ConvertFloatToString(value); - await InsertSingleValueAsync(table.TableName, columnName, valueString); - await SelectAndValidateValuesAsync(table.TableName, columnName, value, 1); - string whereClause = GetWhereClause(columnName, value); - await DeleteFromTableAsync(table.TableName, whereClause, 1); } } } diff --git a/csharp/test/Drivers/Apache/Spark/Resources/SparkData-Databricks.sql b/csharp/test/Drivers/Apache/Spark/Resources/SparkData-Databricks.sql new file mode 100644 index 0000000000..f8f44fc542 --- /dev/null +++ b/csharp/test/Drivers/Apache/Spark/Resources/SparkData-Databricks.sql @@ -0,0 +1,131 @@ + + -- Licensed to the Apache Software Foundation (ASF) under one or more + -- contributor license agreements. See the NOTICE file distributed with + -- this work for additional information regarding copyright ownership. + -- The ASF licenses this file to You under the Apache License, Version 2.0 + -- (the "License"); you may not use this file except in compliance with + -- the License. You may obtain a copy of the License at + + -- http://www.apache.org/licenses/LICENSE-2.0 + + -- Unless required by applicable law or agreed to in writing, software + -- distributed under the License is distributed on an "AS IS" BASIS, + -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + -- See the License for the specific language governing permissions and + -- limitations under the License. 
+ +CREATE OR REPLACE TABLE {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id LONG, + byte BYTE, + short SHORT, + integer INT, + float FLOAT, + number DOUBLE, + decimal NUMERIC(38, 9), + is_active BOOLEAN, + name STRING, + data BINARY, + date DATE, + timestamp TIMESTAMP, + timestamp_ntz TIMESTAMP_NTZ, + timestamp_ltz TIMESTAMP_LTZ, + numbers ARRAY, + person STRUCT < + name STRING, + age LONG + >, + map MAP < + INT, + STRING + >, + varchar VARCHAR(255), + char CHAR(10) +) USING DELTA; + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 1, + 2, 3, 4, 7.89, 1.23, 4.56, + TRUE, + 'John Doe', + -- hex-encoded value `abc123` + X'616263313233', + '2023-09-08', '2023-09-08 12:34:56', '2023-09-08 12:34:56', '2023-09-08 12:34:56+00:00', + ARRAY(1, 2, 3), + STRUCT('John Doe', 30), + MAP(1, 'John Doe'), + 'John Doe', + 'John Doe' +); + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 2, + 127, 32767, 2147483647, 3.4028234663852886e+38, 1.7976931348623157e+308, 9.99999999999999999999999999999999E+28BD, + FALSE, + 'Jane Doe', + -- hex-encoded `def456` + X'646566343536', + '2023-09-09', '2023-09-09 13:45:57', '2023-09-09 13:45:57', '2023-09-09 13:45:57+00:00', + ARRAY(4, 5, 6), + STRUCT('Jane Doe', 40), + MAP(1, 'John Doe'), + 'Jane Doe', + 'Jane Doe' +); + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 3, + -128, -32768, -2147483648, -3.4028234663852886e+38, -1.7976931348623157e+308, -9.99999999999999999999999999999999E+28BD, + FALSE, + 'Jack Doe', + -- hex-encoded `def456` + X'646566343536', + '1556-01-02', '1970-01-01 00:00:00', '1970-01-01 00:00:00', '9999-12-31 23:59:59+00:00', + ARRAY(7, 8, 9), + STRUCT('Jack Doe', 50), + MAP(1, 'John Doe'), + 'Jack Doe', + 'Jack Doe' +); + +UPDATE {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + SET short = 0 + WHERE id = 3; + +DELETE FROM {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + WHERE id = 3; diff --git a/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql b/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql index 908ffbb930..9c0a41e0b5 100644 --- a/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql +++ b/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql @@ -14,8 +14,6 @@ -- See the License for the specific language governing permissions and -- limitations under the License. 
-DROP TABLE IF EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}; - CREATE TABLE IF NOT EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( id LONG, byte BYTE, @@ -29,8 +27,6 @@ CREATE TABLE IF NOT EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( data BINARY, date DATE, timestamp TIMESTAMP, - timestamp_ntz TIMESTAMP_NTZ, - timestamp_ltz TIMESTAMP_LTZ, numbers ARRAY, person STRUCT < name STRING, @@ -44,90 +40,57 @@ CREATE TABLE IF NOT EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( char CHAR(10) ); -INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( - id, - byte, short, integer, float, number, decimal, - is_active, - name, data, - date, timestamp, timestamp_ntz, timestamp_ltz, - numbers, - person, - map, - varchar, - char -) -VALUES ( +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + SELECT 1, 2, 3, 4, 7.89, 1.23, 4.56, TRUE, 'John Doe', -- hex-encoded value `abc123` X'616263313233', - '2023-09-08', '2023-09-08 12:34:56', '2023-09-08 12:34:56', '2023-09-08 12:34:56+00:00', + to_date('2023-09-08'), to_timestamp('2023-09-08T12:34:56.000Z'), ARRAY(1, 2, 3), STRUCT('John Doe', 30), MAP(1, 'John Doe'), 'John Doe', 'John Doe' -); +; -INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( - id, - byte, short, integer, float, number, decimal, - is_active, - name, data, - date, timestamp, timestamp_ntz, timestamp_ltz, - numbers, - person, - map, - varchar, - char -) -VALUES ( +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + SELECT 2, 127, 32767, 2147483647, 3.4028234663852886e+38, 1.7976931348623157e+308, 9.99999999999999999999999999999999E+28BD, FALSE, 'Jane Doe', -- hex-encoded `def456` X'646566343536', - '2023-09-09', '2023-09-09 13:45:57', '2023-09-09 13:45:57', '2023-09-09 13:45:57+00:00', + to_date('2023-09-09'), to_timestamp('2023-09-09T13:45:57.000Z'), ARRAY(4, 5, 6), STRUCT('Jane Doe', 40), MAP(1, 'John Doe'), 'Jane Doe', 'Jane Doe' -); +; -INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( - id, - byte, short, integer, float, number, decimal, - is_active, - name, data, - date, timestamp, timestamp_ntz, timestamp_ltz, - numbers, - person, - map, - varchar, - char -) -VALUES ( +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + SELECT 3, -128, -32768, -2147483648, -3.4028234663852886e+38, -1.7976931348623157e+308, -9.99999999999999999999999999999999E+28BD, FALSE, 'Jack Doe', -- hex-encoded `def456` X'646566343536', - '1556-01-02', '1970-01-01 00:00:00', '1970-01-01 00:00:00', '9999-12-31 23:59:59+00:00', + to_date('1556-01-02'), to_timestamp('1970-01-01T00:00:00.000Z'), ARRAY(7, 8, 9), STRUCT('Jack Doe', 50), MAP(1, 'John Doe'), 'Jack Doe', 'Jack Doe' -); +; -UPDATE {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} - SET short = 0 - WHERE id = 3; +-- UPDATE {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} +-- SET short = 0 +-- WHERE id = 3; -DELETE FROM {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} - WHERE id = 3; +-- DELETE FROM {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} +-- WHERE id = 3; diff --git a/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig-databricks.json b/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig-databricks.json new file mode 100644 index 0000000000..be8d078d8c --- /dev/null +++ b/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig-databricks.json @@ -0,0 +1,16 @@ +{ + "hostName": "", + "path": "", + "token": "", + "auth_type": "token", + "type": "databricks", + "data_type_conv": "none", + "query": "", + "expectedResults": 0, + "metadata": { + "catalog": "hive_metastore", + "schema": "default", + "table": "", + 
"expectedColumnCount": 19 + } +} diff --git a/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig-http.json b/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig-http.json new file mode 100644 index 0000000000..c2899990bd --- /dev/null +++ b/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig-http.json @@ -0,0 +1,15 @@ +{ + "uri": "http://:10001/cliservice", + "auth_type": "basic", + "username": "", + "password": "", + "type": "http", + "data_type_conv": "none", + "query": "", + "expectedResults": 0, + "metadata": { + "schema": "default", + "table": "", + "expectedColumnCount": 17 + } +} diff --git a/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig.json b/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig.json deleted file mode 100644 index fb76fa164c..0000000000 --- a/csharp/test/Drivers/Apache/Spark/Resources/sparkconfig.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "hostName": "", - "path": "", - "token": "", - "query": "", - "expectedResults": 0, - "metadata": { - "catalog": "", - "schema": "", - "table": "", - "expectedColumnCount": 0 - } -} diff --git a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs new file mode 100644 index 0000000000..34e971bd8f --- /dev/null +++ b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs @@ -0,0 +1,322 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Thrift.Transport; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark +{ + /// + /// Class for testing the Spark ADBC connection tests. + /// + public class SparkConnectionTest : TestBase + { + public SparkConnectionTest(ITestOutputHelper? outputHelper) : base(outputHelper, new SparkTestEnvironment.Factory()) + { + Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); + } + + /// + /// Validates database can detect invalid connection parameter combinations. + /// + [SkippableTheory] + [ClassData(typeof(InvalidConnectionParametersTestData))] + internal void CanDetectConnectionParameterErrors(ParametersWithExceptions test) + { + AdbcDriver driver = NewDriver; + AdbcDatabase database = driver.Open(test.Parameters); + Exception exeption = Assert.Throws(test.ExceptionType, () => database.Connect(test.Parameters)); + OutputHelper?.WriteLine(exeption.Message); + } + + /// + /// Tests connection timeout to establish a session with the backend. 
+ /// + /// The timeout (in ms) + /// The exception type to expect (if any) + /// An alternate exception that may occur (if any) + [SkippableTheory] + [InlineData(0, null, null)] + [InlineData(1, typeof(TimeoutException), typeof(TTransportException))] + [InlineData(10, typeof(TimeoutException), typeof(TTransportException))] + [InlineData(30000, null, null)] + [InlineData(null, null, null)] + public void ConnectionTimeoutTest(int? connectTimeoutMilliseconds, Type? exceptionType, Type? alternateExceptionType) + { + SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone(); + + if (connectTimeoutMilliseconds.HasValue) + testConfiguration.ConnectTimeoutMilliseconds = connectTimeoutMilliseconds.Value.ToString(); + + OutputHelper?.WriteLine($"ConnectTimeoutMilliseconds: {testConfiguration.ConnectTimeoutMilliseconds}. ShouldSucceed: {exceptionType == null}"); + + try + { + NewConnection(testConfiguration); + } + catch(AggregateException aex) + { + if (exceptionType != null) + { + if (alternateExceptionType != null && aex.InnerException?.GetType() != exceptionType) + { + if (aex.InnerException?.GetType() == typeof(HiveServer2Exception)) + { + // a TTransportException is inside a HiveServer2Exception + Assert.IsType(alternateExceptionType, aex.InnerException!.InnerException); + } + else + { + throw; + } + } + else + { + Assert.IsType(exceptionType, aex.InnerException); + } + } + else + { + throw; + } + } + } + + /// + /// Tests the various metadata calls on a SparkConnection + /// + /// + [SkippableTheory] + [ClassData(typeof(MetadataTimeoutTestData))] + internal void MetadataTimeoutTest(MetadataWithExceptions metadataWithException) + { + SparkTestConfiguration testConfiguration = (SparkTestConfiguration)TestConfiguration.Clone(); + + if (metadataWithException.QueryTimeoutSeconds.HasValue) + testConfiguration.QueryTimeoutSeconds = metadataWithException.QueryTimeoutSeconds.Value.ToString(); + + OutputHelper?.WriteLine($"Action: {metadataWithException.ActionName}. QueryTimeoutSeconds: {testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: {metadataWithException.ExceptionType == null}"); + + try + { + metadataWithException.MetadataAction(testConfiguration); + } + catch (Exception ex) when (ApacheUtility.ContainsException(ex, metadataWithException.ExceptionType, out Exception? containedException)) + { + Assert.IsType(metadataWithException.ExceptionType!, containedException); + } + catch (Exception ex) when (ApacheUtility.ContainsException(ex, metadataWithException.AlternateExceptionType, out Exception? containedException)) + { + Assert.IsType(metadataWithException.AlternateExceptionType!, containedException); + } + } + + /// + /// Data type used for metadata timeout tests. + /// + internal class MetadataWithExceptions + { + public MetadataWithExceptions(int? queryTimeoutSeconds, string actionName, Action action, Type? exceptionType, Type? alternateExceptionType) + { + QueryTimeoutSeconds = queryTimeoutSeconds; + ActionName = actionName; + MetadataAction = action; + ExceptionType = exceptionType; + AlternateExceptionType = alternateExceptionType; + } + + /// + /// If null, uses the default timeout. + /// + public int? QueryTimeoutSeconds { get; } + + public string ActionName { get; } + + /// + /// If null, expected to succeed. + /// + public Type? ExceptionType { get; } + + /// + /// Sometimes you can expect one but may get another. + /// For example, on GetObjectsAll, sometimes a TTransportException is expected but a TaskCanceledException is received during the test. 
+ /// + public Type? AlternateExceptionType { get; } + + /// + /// The metadata action to perform. + /// + public Action MetadataAction { get; } + } + + /// + /// Used for testing timeouts on metadata calls. + /// + internal class MetadataTimeoutTestData : TheoryData + { + public MetadataTimeoutTestData() + { + SparkConnectionTest sparkConnectionTest = new SparkConnectionTest(null); + + Action getObjectsAll = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetObjects(AdbcConnection.GetObjectsDepth.All, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table, null, null); + }; + + Action getObjectsCatalogs = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetObjects(AdbcConnection.GetObjectsDepth.Catalogs, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); + }; + + Action getObjectsDbSchemas = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetObjects(AdbcConnection.GetObjectsDepth.DbSchemas, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); + }; + + Action getObjectsTables = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetObjects(AdbcConnection.GetObjectsDepth.Tables, testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Schema, null, null); + }; + + AddAction("getObjectsAll", getObjectsAll, new List() { null, typeof(TimeoutException), null, null, null } ); + AddAction("getObjectsCatalogs", getObjectsCatalogs); + AddAction("getObjectsDbSchemas", getObjectsDbSchemas); + AddAction("getObjectsTables", getObjectsTables); + + Action getTableTypes = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetTableTypes(); + }; + + AddAction("getTableTypes", getTableTypes); + + Action getTableSchema = (testConfiguration) => + { + AdbcConnection cn = sparkConnectionTest.NewConnection(testConfiguration); + cn.GetTableSchema(testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table); + }; + + AddAction("getTableSchema", getTableSchema); + } + + /// + /// Adds the action with the default timeouts. + /// + /// The friendly name of the action. + /// The action to perform. + /// Optional list of alternate exceptions that are possible. Must have 5 items if present. + private void AddAction(string name, Action action, List? alternateExceptions = null) + { + List expectedExceptions = new List() + { + null, // QueryTimeout = 0 + typeof(TTransportException), // QueryTimeout = 1 + typeof(TimeoutException), // QueryTimeout = 10 + null, // QueryTimeout = default + null // QueryTimeout = 300 + }; + + AddAction(name, action, expectedExceptions, alternateExceptions); + } + + /// + /// Adds the action with the default timeouts. + /// + /// The action to perform. + /// The expected exceptions. + /// + /// For List the position is based on the behavior when: + /// [0] QueryTimeout = 0 + /// [1] QueryTimeout = 1 + /// [2] QueryTimeout = 10 + /// [3] QueryTimeout = default + /// [4] QueryTimeout = 300 + /// + private void AddAction(string name, Action action, List expectedExceptions, List? 
alternateExceptions) + { + Assert.True(expectedExceptions.Count == 5); + + if (alternateExceptions != null) + { + Assert.True(alternateExceptions.Count == 5); + } + + Add(new(0, name, action, expectedExceptions[0], alternateExceptions?[0])); + Add(new(1, name, action, expectedExceptions[1], alternateExceptions?[1])); + Add(new(10, name, action, expectedExceptions[2], alternateExceptions?[2])); + Add(new(null, name, action, expectedExceptions[3], alternateExceptions?[3])); + Add(new(300, name, action, expectedExceptions[4], alternateExceptions?[4])); + } + } + + internal class ParametersWithExceptions + { + public ParametersWithExceptions(Dictionary parameters, Type exceptionType) + { + Parameters = parameters; + ExceptionType = exceptionType; + } + + public IReadOnlyDictionary Parameters { get; } + public Type ExceptionType { get; } + } + + internal class InvalidConnectionParametersTestData : TheoryData + { + public InvalidConnectionParametersTestData() + { + Add(new([], typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = " " }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = "xxx" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Standard }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = " " }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "invalid!server.com" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "http://valid.server.com" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"unknown_auth_type" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Basic}" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Token}" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Basic}", [SparkParameters.Token] = "abcdef" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.AuthType] = $"{SparkAuthTypeConstants.Token}", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Password] = "myPassword" }, typeof(ArgumentException))); + Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = "-1" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = IPEndPoint.MinPort.ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = (IPEndPoint.MaxPort + 1).ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "http-//hostname.com" }, typeof(UriFormatException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = "httpxxz://hostname.com:1234567890" }, typeof(UriFormatException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = ((long)int.MaxValue + 1).ToString() }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "non-numeric" }, typeof(ArgumentOutOfRangeException))); + Add(new(new() { [SparkParameters.Type] = SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", [AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", [SparkParameters.ConnectTimeoutMilliseconds] = "" }, typeof(ArgumentOutOfRangeException))); + } + } + } +} diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestBase.cs b/csharp/test/Drivers/Apache/Spark/SparkTestBase.cs deleted file mode 100644 index faaf180bcc..0000000000 --- a/csharp/test/Drivers/Apache/Spark/SparkTestBase.cs +++ /dev/null @@ -1,66 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using Apache.Arrow.Adbc.Drivers.Apache.Spark; -using Xunit.Abstractions; - -namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark -{ - public class SparkTestBase : TestBase - { - public SparkTestBase(ITestOutputHelper? outputHelper) : base(outputHelper) { } - - protected override string TestConfigVariable => "SPARK_TEST_CONFIG_FILE"; - - protected override string SqlDataResourceLocation => "Spark/Resources/SparkData.sql"; - - protected override AdbcDriver NewDriver => new SparkDriver(); - - protected override async ValueTask NewTemporaryTableAsync(AdbcStatement statement, string columns) { - string tableName = NewTableName(); - // Note: Databricks/Spark doesn't support TEMPORARY table. - string sqlUpdate = string.Format("CREATE TABLE {0} ({1})", tableName, columns); - OutputHelper?.WriteLine(sqlUpdate); - return await TemporaryTable.NewTemporaryTableAsync(statement, tableName, sqlUpdate); - } - - protected override string Delimiter => "`"; - - protected override Dictionary GetDriverParameters(SparkTestConfiguration testConfiguration) - { - Dictionary parameters = new(StringComparer.OrdinalIgnoreCase); - - if (!string.IsNullOrEmpty(testConfiguration.HostName)) - { - parameters.Add(SparkParameters.HostName, testConfiguration.HostName!); - } - if (!string.IsNullOrEmpty(testConfiguration.Path)) - { - parameters.Add(SparkParameters.Path, testConfiguration.Path!); - } - if (!string.IsNullOrEmpty(testConfiguration.Token)) - { - parameters.Add(SparkParameters.Token, testConfiguration.Token!); - } - - return parameters; - } - } -} diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs b/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs index 90dd7150ba..5ada5abeb3 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs @@ -15,10 +15,15 @@ * limitations under the License. */ +using System.Text.Json.Serialization; + namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { public class SparkTestConfiguration : ApacheTestConfiguration { + [JsonPropertyName("token"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string Token { get; set; } = string.Empty; + } } diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs new file mode 100644 index 0000000000..16a5501118 --- /dev/null +++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs @@ -0,0 +1,291 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Text; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2; +using Apache.Arrow.Types; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark +{ + public class SparkTestEnvironment : HiveServer2TestEnvironment + { + public class Factory : Factory + { + public override SparkTestEnvironment Create(Func getConnection) => new(getConnection); + } + + private SparkTestEnvironment(Func getConnection) : base(getConnection) { } + + public override string TestConfigVariable => "SPARK_TEST_CONFIG_FILE"; + + public override string SqlDataResourceLocation => ServerType == SparkServerType.Databricks + ? "Spark/Resources/SparkData-Databricks.sql" + : "Spark/Resources/SparkData.sql"; + + public override int ExpectedColumnCount => ServerType == SparkServerType.Databricks ? 19 : 17; + + public override AdbcDriver CreateNewDriver() => new SparkDriver(); + + public override string GetCreateTemporaryTableStatement(string tableName, string columns) + { + return string.Format("CREATE TABLE {0} ({1})", tableName, columns); + } + + public override string Delimiter => "`"; + + public override Dictionary GetDriverParameters(SparkTestConfiguration testConfiguration) + { + Dictionary parameters = new(StringComparer.OrdinalIgnoreCase); + + if (!string.IsNullOrEmpty(testConfiguration.HostName)) + { + parameters.Add(SparkParameters.HostName, testConfiguration.HostName!); + } + if (!string.IsNullOrEmpty(testConfiguration.Uri)) + { + parameters.Add(AdbcOptions.Uri, testConfiguration.Uri!); + } + if (!string.IsNullOrEmpty(testConfiguration.Port)) + { + parameters.Add(SparkParameters.Port, testConfiguration.Port!); + } + if (!string.IsNullOrEmpty(testConfiguration.Path)) + { + parameters.Add(SparkParameters.Path, testConfiguration.Path!); + } + if (!string.IsNullOrEmpty(testConfiguration.Token)) + { + parameters.Add(SparkParameters.Token, testConfiguration.Token!); + } + if (!string.IsNullOrEmpty(testConfiguration.Username)) + { + parameters.Add(AdbcOptions.Username, testConfiguration.Username!); + } + if (!string.IsNullOrEmpty(testConfiguration.Password)) + { + parameters.Add(AdbcOptions.Password, testConfiguration.Password!); + } + if (!string.IsNullOrEmpty(testConfiguration.AuthType)) + { + parameters.Add(SparkParameters.AuthType, testConfiguration.AuthType!); + } + if (!string.IsNullOrEmpty(testConfiguration.Type)) + { + parameters.Add(SparkParameters.Type, testConfiguration.Type!); + } + if (!string.IsNullOrEmpty(testConfiguration.DataTypeConversion)) + { + parameters.Add(SparkParameters.DataTypeConv, testConfiguration.DataTypeConversion!); + } + if (!string.IsNullOrEmpty(testConfiguration.TlsOptions)) + { + parameters.Add(SparkParameters.TLSOptions, testConfiguration.TlsOptions!); + } + if (!string.IsNullOrEmpty(testConfiguration.BatchSize)) + { + parameters.Add(ApacheParameters.BatchSize, testConfiguration.BatchSize!); + } + if (!string.IsNullOrEmpty(testConfiguration.PollTimeMilliseconds)) + { + parameters.Add(ApacheParameters.PollTimeMilliseconds, testConfiguration.PollTimeMilliseconds!); + } + if (!string.IsNullOrEmpty(testConfiguration.ConnectTimeoutMilliseconds)) + { + parameters.Add(SparkParameters.ConnectTimeoutMilliseconds, testConfiguration.ConnectTimeoutMilliseconds!); + } + if (!string.IsNullOrEmpty(testConfiguration.QueryTimeoutSeconds)) + { + 
parameters.Add(ApacheParameters.QueryTimeoutSeconds, testConfiguration.QueryTimeoutSeconds!); + } + + return parameters; + } + + internal SparkServerType ServerType => ((SparkConnection)Connection).ServerType; + + public override string VendorVersion => ((HiveServer2Connection)Connection).VendorVersion; + + public override bool SupportsDelete => ServerType == SparkServerType.Databricks; + + public override bool SupportsUpdate => ServerType == SparkServerType.Databricks; + + public override bool SupportCatalogName => ServerType == SparkServerType.Databricks; + + public override bool ValidateAffectedRows => ServerType == SparkServerType.Databricks; + + public override string GetInsertStatement(string tableName, string columnName, string? value) => + string.Format("INSERT INTO {0} ({1}) SELECT {2};", tableName, columnName, value ?? "NULL"); + + public override SampleDataBuilder GetSampleDataBuilder() + { + SampleDataBuilder sampleDataBuilder = new(); + bool dataTypeIsFloat = ServerType == SparkServerType.Databricks || DataTypeConversion.HasFlag(DataTypeConversion.Scalar); + Type floatNetType = dataTypeIsFloat ? typeof(float) : typeof(double); + Type floatArrowType = dataTypeIsFloat ? typeof(FloatType) : typeof(DoubleType); + object floatValue; + if (dataTypeIsFloat) + floatValue = 1f; + else + floatValue = 1d; + + // standard values + sampleDataBuilder.Samples.Add( + new SampleData() + { + Query = "SELECT " + + "CAST(1 as BIGINT) as id, " + + "CAST(2 as INTEGER) as int, " + + "CAST(1 as FLOAT) as number_float, " + + "CAST(4.56 as DOUBLE) as number_double, " + + "4.56BD as decimal, " + + "9.9999999999999999999999999999999999999BD as big_decimal, " + + "CAST(True as BOOLEAN) as is_active, " + + "'John Doe' as name, " + + "X'616263313233' as data, " + + "DATE '2023-09-08' as date, " + + "TIMESTAMP '2023-09-08 12:34:56+00:00' as timestamp, " + + "INTERVAL 178956969 YEAR 11 MONTH as interval, " + + "ARRAY(1, 2, 3) as numbers, " + + "STRUCT('John Doe' as name, 30 as age) as person," + + "MAP('name', CAST('Jane Doe' AS STRING), 'age', CAST(29 AS INT)) as map", + ExpectedValues = + [ + new("id", typeof(long), typeof(Int64Type), 1L), + new("int", typeof(int), typeof(Int32Type), 2), + new("number_float", floatNetType, floatArrowType, floatValue), + new("number_double", typeof(double), typeof(DoubleType), 4.56d), + new("decimal", typeof(SqlDecimal), typeof(Decimal128Type), SqlDecimal.Parse("4.56")), + new("big_decimal", typeof(SqlDecimal), typeof(Decimal128Type), SqlDecimal.Parse("9.9999999999999999999999999999999999999")), + new("is_active", typeof(bool), typeof(BooleanType), true), + new("name", typeof(string), typeof(StringType), "John Doe"), + new("data", typeof(byte[]), typeof(BinaryType), UTF8Encoding.UTF8.GetBytes("abc123")), + new("date", typeof(DateTime), typeof(Date32Type), new DateTime(2023, 9, 8)), + new("timestamp", typeof(DateTimeOffset), typeof(TimestampType), new DateTimeOffset(new DateTime(2023, 9, 8, 12, 34, 56), TimeSpan.Zero)), + new("interval", typeof(string), typeof(StringType), "178956969-11"), + new("numbers", typeof(string), typeof(StringType), "[1,2,3]"), + new("person", typeof(string), typeof(StringType), """{"name":"John Doe","age":30}"""), + new("map", typeof(string), typeof(StringType), """{"age":"29","name":"Jane Doe"}""") // This is unexpected JSON. Expecting 29 to be a numeric and not string. 
+ ] + }); + + sampleDataBuilder.Samples.Add( + new SampleData() + { + Query = "SELECT " + + "CAST(NULL as BIGINT) as id, " + + "CAST(NULL as INTEGER) as int, " + + "CAST(NULL as FLOAT) as number_float, " + + "CAST(NULL as DOUBLE) as number_double, " + + "CAST(NULL as DECIMAL(38,2)) as decimal, " + + "CAST(NULL as BOOLEAN) as is_active, " + + "CAST(NULL as STRING) as name, " + + "CAST(NULL as BINARY) as data, " + + "CAST(NULL as DATE) as date, " + + "CAST(NULL as TIMESTAMP) as timestamp," + + "CAST(NULL as MAP) as map, " + + "CAST(NULL as ARRAY) as numbers, " + + "CAST(NULL as STRUCT) as person, " + + "MAP(CAST('EMPTY' as STRING), CAST(NULL as INTEGER)) as map_null, " + + "ARRAY(NULL,NULL,NULL) as numbers_null, " + + "STRUCT(CAST(NULL as STRING), CAST(NULL as INTEGER)) as person_null", + ExpectedValues = + [ + new("id", typeof(long), typeof(Int64Type), null), + new("int", typeof(int), typeof(Int32Type), null), + new("number_float", floatNetType, floatArrowType, null), + new("number_double", typeof(double), typeof(DoubleType), null), + new("decimal", typeof(SqlDecimal), typeof(Decimal128Type), null), + new("is_active", typeof(bool), typeof(BooleanType), null), + new("name", typeof(string), typeof(StringType), null), + new("data", typeof(byte[]), typeof(BinaryType), null), + new("date", typeof(DateTime), typeof(Date32Type), null), + new("timestamp", typeof(DateTimeOffset), typeof(TimestampType), null), + new("map", typeof(string), typeof(StringType), null), + new("numbers", typeof(string), typeof(StringType), null), + new("person", typeof(string), typeof(StringType), null), + new("map_null", typeof(string), typeof(StringType), """{"EMPTY":null}"""), + new("numbers_null", typeof(string), typeof(StringType), """[null,null,null]"""), + new("person_null", typeof(string), typeof(StringType), """{"col1":null,"col2":null}"""), + ] + }); + + // complex struct + sampleDataBuilder.Samples.Add( + new SampleData() + { + Query = "SELECT " + + "STRUCT(" + + "\"Iron Man\" as name," + + "\"Avengers\" as team," + + "ARRAY(\"Genius\", \"Billionaire\", \"Playboy\", \"Philanthropist\") as powers," + + "ARRAY(" + + " STRUCT(" + + " \"Captain America\" as name, " + + " \"Avengers\" as team, " + + " ARRAY(\"Super Soldier Serum\", \"Vibranium Shield\") as powers, " + + " ARRAY(" + + " STRUCT(" + + " \"Thanos\" as name, " + + " \"Black Order\" as team, " + + " ARRAY(\"Infinity Gauntlet\", \"Super Strength\", \"Teleportation\") as powers, " + + " ARRAY(" + + " STRUCT(" + + " \"Loki\" as name, " + + " \"Asgard\" as team, " + + " ARRAY(\"Magic\", \"Shapeshifting\", \"Trickery\") as powers " + + " )" + + " ) as allies" + + " )" + + " ) as enemies" + + " )," + + " STRUCT(" + + " \"Spider-Man\" as name, " + + " \"Avengers\" as team, " + + " ARRAY(\"Spider-Sense\", \"Web-Shooting\", \"Wall-Crawling\") as powers, " + + " ARRAY(" + + " STRUCT(" + + " \"Green Goblin\" as name, " + + " \"Sinister Six\" as team, " + + " ARRAY(\"Glider\", \"Pumpkin Bombs\", \"Super Strength\") as powers, " + + " ARRAY(" + + " STRUCT(" + + " \"Doctor Octopus\" as name, " + + " \"Sinister Six\" as team, " + + " ARRAY(\"Mechanical Arms\", \"Genius\", \"Madness\") as powers " + + " )" + + " ) as allies" + + " )" + + " ) as enemies" + + " )" + + " ) as friends" + + ") as iron_man", + ExpectedValues = + [ + new("iron_man", typeof(string), typeof(StringType), "{\"name\":\"Iron Man\",\"team\":\"Avengers\",\"powers\":[\"Genius\",\"Billionaire\",\"Playboy\",\"Philanthropist\"],\"friends\":[{\"name\":\"Captain 
America\",\"team\":\"Avengers\",\"powers\":[\"Super Soldier Serum\",\"Vibranium Shield\"],\"enemies\":[{\"name\":\"Thanos\",\"team\":\"Black Order\",\"powers\":[\"Infinity Gauntlet\",\"Super Strength\",\"Teleportation\"],\"allies\":[{\"name\":\"Loki\",\"team\":\"Asgard\",\"powers\":[\"Magic\",\"Shapeshifting\",\"Trickery\"]}]}]},{\"name\":\"Spider-Man\",\"team\":\"Avengers\",\"powers\":[\"Spider-Sense\",\"Web-Shooting\",\"Wall-Crawling\"],\"enemies\":[{\"name\":\"Green Goblin\",\"team\":\"Sinister Six\",\"powers\":[\"Glider\",\"Pumpkin Bombs\",\"Super Strength\"],\"allies\":[{\"name\":\"Doctor Octopus\",\"team\":\"Sinister Six\",\"powers\":[\"Mechanical Arms\",\"Genius\",\"Madness\"]}]}]}]}") + ] + }); + + return sampleDataBuilder; + } + } +} diff --git a/csharp/test/Drivers/Apache/Spark/SqlTypeNameParserTests.cs b/csharp/test/Drivers/Apache/Spark/SqlTypeNameParserTests.cs new file mode 100644 index 0000000000..8ec4f5390f --- /dev/null +++ b/csharp/test/Drivers/Apache/Spark/SqlTypeNameParserTests.cs @@ -0,0 +1,312 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark +{ + public class SqlTypeNameParserTests(ITestOutputHelper outputHelper) + { + private readonly ITestOutputHelper _outputHelper = outputHelper; + + [Theory()] + [InlineData("ARRAY", "ARRAY")] + [InlineData("ARRAY < INT >", "ARRAY")] + [InlineData(" ARRAY < ARRAY < INT > > ", "ARRAY")] + [InlineData("ARRAY", "ARRAY")] + [InlineData("DATE", "DATE")] + [InlineData("dec(15)", "DECIMAL")] + [InlineData("numeric", "DECIMAL")] + [InlineData("STRUCT", "STRUCT")] + [InlineData("STRUCT< F1 INT >", "STRUCT")] + [InlineData("STRUCT < F1: ARRAY < INT > > ", "STRUCT")] + [InlineData("STRUCT>", "STRUCT")] + [InlineData("MAP", "MAP")] + [InlineData("MAP< INT , VARCHAR(255) >", "MAP")] + [InlineData("MAP < ARRAY < INT >, INT > ", "MAP")] + [InlineData("TIMESTAMP", "TIMESTAMP")] + [InlineData("TIMESTAMP_LTZ", "TIMESTAMP")] + [InlineData("TIMESTAMP_NTZ", "TIMESTAMP")] + internal void CanParseAnyType(string testTypeName, string expectedBaseTypeName) + { + SqlTypeNameParserResult result = SqlTypeNameParser.Parse(testTypeName); + Assert.NotNull(result); + Assert.Equal(testTypeName, result.TypeName); + Assert.Equal(expectedBaseTypeName, result.BaseTypeName); + } + + [Theory()] + [InlineData("BIGINT", "BIGINT")] + [InlineData("BINARY", "BINARY")] + [InlineData("BOOLEAN", "BOOLEAN")] + [InlineData("DATE", "DATE")] + [InlineData("DOUBLE", "DOUBLE")] + [InlineData("FLOAT", "FLOAT")] + [InlineData("SMALLINT", "SMALLINT")] + [InlineData("TINYINT", "TINYINT")] + internal void CanParseSimpleTypeName(string testTypeName, string expectedBaseTypeName) + { + Assert.True(SqlTypeNameParser.TryParse(testTypeName, out SqlTypeNameParserResult? result)); + Assert.NotNull(result); + Assert.Equal(expectedBaseTypeName, result.BaseTypeName); + } + + [Theory()] + [InlineData("INTERVAL YEAR", "INTERVAL")] + [InlineData("INTERVAL MONTH", "INTERVAL")] + [InlineData("INTERVAL DAY", "INTERVAL")] + [InlineData("INTERVAL HOUR", "INTERVAL")] + [InlineData("INTERVAL MINUTE", "INTERVAL")] + [InlineData("INTERVAL SECOND", "INTERVAL")] + [InlineData("INTERVAL YEAR TO MONTH", "INTERVAL")] + [InlineData("INTERVAL DAY TO HOUR", "INTERVAL")] + [InlineData("INTERVAL DAY TO MINUTE", "INTERVAL")] + [InlineData("INTERVAL DAY TO SECOND", "INTERVAL")] + [InlineData("INTERVAL HOUR TO MINUTE", "INTERVAL")] + [InlineData("INTERVAL HOUR TO SECOND", "INTERVAL")] + [InlineData("INTERVAL MINUTE TO SECOND", "INTERVAL")] + internal void CanParseInterval(string testTypeName, string expectedBaseTypeName) + { + Assert.True(SqlTypeNameParser.TryParse(testTypeName, out SqlTypeNameParserResult? result)); + Assert.NotNull(result); + Assert.Equal(expectedBaseTypeName, result.BaseTypeName); + } + + [Theory()] + [MemberData(nameof(GenerateCharTestData), "CHAR")] + [MemberData(nameof(GenerateCharTestData), "NCHAR")] + [MemberData(nameof(GenerateCharTestData), "CHaR")] + internal void CanParseChar(string testTypeName, SqlCharVarcharParserResult expected) + { + _outputHelper.WriteLine(testTypeName); + Assert.True(SqlCharTypeParser.Default.TryParse(testTypeName, out SqlCharVarcharParserResult? 
result)); + Assert.NotNull(result); + Assert.Equal(expected, result); + } + + [Theory()] + [MemberData(nameof(GenerateVarcharTestData), "VARCHAR")] + [MemberData(nameof(GenerateVarcharTestData), "LONGVARCHAR")] + [MemberData(nameof(GenerateVarcharTestData), "NVARCHAR")] + [MemberData(nameof(GenerateVarcharTestData), "LONGNVARCHAR")] + [MemberData(nameof(GenerateVarcharTestData), "VaRCHaR")] + internal void CanParseVarchar(string testTypeName, SqlCharVarcharParserResult expected) + { + _outputHelper.WriteLine(testTypeName); + Assert.True(SqlVarcharTypeParser.Default.TryParse(testTypeName, out SqlCharVarcharParserResult? result)); + Assert.NotNull(result); + Assert.Equal(expected, result); + } + + [Theory()] + [MemberData(nameof(GenerateDecimalTestData), "DECIMAL")] + [MemberData(nameof(GenerateDecimalTestData), "DEC")] + [MemberData(nameof(GenerateDecimalTestData), "NUMERIC")] + [MemberData(nameof(GenerateDecimalTestData), "DeCiMaL")] + internal void CanParseDecimal(string testTypeName, SqlDecimalParserResult expected) + { + _outputHelper.WriteLine(testTypeName); + Assert.True(SqlDecimalTypeParser.Default.TryParse(testTypeName, out SqlDecimalParserResult? result)); + Assert.NotNull(result); + Assert.Equal(expected.TypeName, result.TypeName); + Assert.Equal(expected.BaseTypeName, result.BaseTypeName); + // Note: Decimal128Type does not override Equals/GetHashCode + Assert.Equal(expected.Decimal128Type.Name, result.Decimal128Type.Name); + Assert.Equal(expected.Decimal128Type.Precision, result.Decimal128Type.Precision); + Assert.Equal(expected.Decimal128Type.Scale, result.Decimal128Type.Scale); + } + + [Theory()] + [InlineData("INT")] + [InlineData("INTEGER")] + [InlineData(" INT ")] + [InlineData(" INTEGER ")] + [InlineData(" iNTeGeR ")] + internal void CanParseInteger(string testTypeName) + { + string baseTypeName = SqlIntegerTypeParser.Default.BaseTypeName; + SqlTypeNameParserResult expected = new(testTypeName, baseTypeName); + _outputHelper.WriteLine(testTypeName); + Assert.True(SqlIntegerTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result)); + Assert.NotNull(result); + Assert.Equal(expected, result); + } + + [Theory()] + [InlineData("TIMESTAMP")] + [InlineData("TIMESTAMP_LTZ")] + [InlineData("TIMESTAMP_NTZ")] + [InlineData("TiMeSTaMP")] + internal void CanParseTimestamp(string testTypeName) + { + string baseTypeName = SqlTimestampTypeParser.Default.BaseTypeName; + SqlTypeNameParserResult expected = new(testTypeName, baseTypeName); + _outputHelper.WriteLine(testTypeName); + Assert.True(SqlTimestampTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result)); + Assert.NotNull(result); + Assert.Equal(expected, result); + } + + [Theory()] + [InlineData("ARRAY")] + [InlineData("ARRAY < INT >")] + [InlineData(" ARRAY < ARRAY < INT > > ")] + [InlineData("ARRAY")] + [InlineData("aRRaY")] + internal void CanParseArray(string testTypeName) + { + string baseTypeName = SqlArrayTypeParser.Default.BaseTypeName; + SqlTypeNameParserResult expected = new(testTypeName, baseTypeName); + _outputHelper.WriteLine(testTypeName); + Assert.True(SqlArrayTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? 
result)); + Assert.NotNull(result); + Assert.Equal(expected, result); + } + + [Theory()] + [InlineData("MAP")] + [InlineData("MAP< INT , VARCHAR(255) >")] + [InlineData("MAP < ARRAY < INT >, INT > ")] + [InlineData("MaP")] + internal void CanParseMap(string testTypeName) + { + string baseTypeName = SqlMapTypeParser.Default.BaseTypeName; + SqlTypeNameParserResult expected = new(testTypeName, baseTypeName); + _outputHelper.WriteLine(testTypeName); + Assert.True(SqlMapTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result)); + Assert.NotNull(result); + Assert.Equal(expected, result); + } + + [Theory()] + [InlineData("STRUCT")] + [InlineData("STRUCT< F1 INT >")] + [InlineData("STRUCT < F1: ARRAY < INT > > ")] + [InlineData("STRUCT>")] + [InlineData("STRuCT")] + internal void CanParseStruct(string testTypeName) + { + string baseTypeName = SqlStructTypeParser.Default.BaseTypeName; + SqlTypeNameParserResult expected = new(testTypeName, baseTypeName); + _outputHelper.WriteLine(testTypeName); + Assert.True(SqlStructTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result)); + Assert.NotNull(result); + Assert.Equal(expected, result); + } + + [Theory()] + [InlineData("ARRAY")] + [InlineData("MAP")] + [InlineData("STRUCT")] + [InlineData("ARRAY<")] + [InlineData("MAP<")] + [InlineData("STRUCT<")] + [InlineData("ARRAY>")] + [InlineData("MAP>")] + [InlineData("STRUCT>")] + [InlineData("INTERVAL")] + [InlineData("TIMESTAMP_ZZZ")] + internal void CannotParseUnexpectedTypeName(string testTypeName) + { + Assert.False(SqlTypeNameParser.TryParse(testTypeName, out _), $"Expecting type {testTypeName} to fail to parse."); + } + + [Fact()] + internal void CanDetectInvalidReturnType() + { + Func testCode = () => SqlTypeNameParser.Parse("INTEGER", (int)SparkConnection.ColumnTypeId.INTEGER); + _outputHelper.WriteLine(Assert.Throws(testCode).Message); + } + + public static IEnumerable GenerateCharTestData(string typeName) + { + int?[] lengths = [1, 10, int.MaxValue,]; + string[] spaces = ["", " ", "\t"]; + string baseTypeName = SqlCharTypeParser.Default.BaseTypeName; + foreach (int? length in lengths) + { + foreach (string leadingSpace in spaces) + { + foreach (string trailingSpace in spaces) + { + string clause = length == null ? "" : $"{leadingSpace}({leadingSpace}{length}{trailingSpace})"; + string testTypeName = $"{leadingSpace}{typeName}{clause}{trailingSpace}"; + SqlCharVarcharParserResult expectedResult = new(testTypeName, baseTypeName, length ?? int.MaxValue); + yield return new object[] { testTypeName, expectedResult }; + } + } + } + } + + public static IEnumerable GenerateVarcharTestData(string typeName) + { + int?[] lengths = [null, 1, 10, int.MaxValue,]; + string[] spaces = ["", " ", "\t"]; + string baseTypeName = SqlVarcharTypeParser.Default.BaseTypeName; + foreach (int? length in lengths) + { + foreach (string leadingSpace in spaces) + { + foreach (string trailingSpace in spaces) + { + string clause = length == null ? "" : $"{leadingSpace}({leadingSpace}{length}{trailingSpace})"; + string testTypeName = $"{leadingSpace}{typeName}{clause}{trailingSpace}"; + SqlCharVarcharParserResult expectedResult = new(testTypeName, baseTypeName, length ?? 
int.MaxValue); + yield return new object[] { testTypeName, expectedResult }; + } + } + } + yield return new object[] { "STRING", new SqlCharVarcharParserResult("STRING", "STRING") }; + } + + public static IEnumerable GenerateDecimalTestData(string typeName) + { + string baseTypeName = SqlDecimalTypeParser.Default.BaseTypeName; + var precisionScales = new[] + { + new { Precision = (int?)null, Scale = (int?)null }, + new { Precision = (int?)1, Scale = (int?)null }, + new { Precision = (int?)1, Scale = (int?)1 }, + new { Precision = (int?)38, Scale = (int?)null }, + new { Precision = (int?)38, Scale = (int?)38 }, + new { Precision = (int?)99, Scale = (int?)null }, + new { Precision = (int?)99, Scale = (int?)99 }, + }; + string[] spaces = ["", " ", "\t"]; + foreach (var precisionScale in precisionScales) + { + foreach (string leadingSpace in spaces) + { + foreach (string trailingSpace in spaces) + { + string clause = precisionScale.Precision == null ? "" + : precisionScale.Scale == null + ? $"({leadingSpace}{precisionScale.Precision}{trailingSpace})" + : $"({leadingSpace}{precisionScale.Precision}{trailingSpace},{leadingSpace}{precisionScale.Scale}{trailingSpace})"; + string testTypeName = $"{leadingSpace}{typeName}{clause}{trailingSpace}"; + SqlDecimalParserResult expectedResult = new(testTypeName, baseTypeName, precisionScale.Precision ?? 10, precisionScale.Scale ?? 0); + yield return new object[] { testTypeName, expectedResult }; + } + } + } + } + } +} diff --git a/csharp/test/Drivers/Apache/Spark/StatementTests.cs b/csharp/test/Drivers/Apache/Spark/StatementTests.cs index e1e44e59f6..aaafc31ba7 100644 --- a/csharp/test/Drivers/Apache/Spark/StatementTests.cs +++ b/csharp/test/Drivers/Apache/Spark/StatementTests.cs @@ -16,91 +16,16 @@ */ using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using Apache.Arrow.Adbc.Drivers.Apache.Spark; -using Apache.Arrow.Adbc.Tests.Xunit; using Xunit; using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { - /// - /// Class for testing the Snowflake ADBC driver connection tests. - /// - /// - /// Tests are ordered to ensure data is created for the other - /// queries to run. - /// - [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] - public class StatementTests : SparkTestBase + public class StatementTests : Common.StatementTests { - private static List DefaultTableTypes => new() { "TABLE", "VIEW" }; - - public StatementTests(ITestOutputHelper? outputHelper) : base(outputHelper) - { - Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); - } - - /// - /// Validates if the SetOption handle valid/invalid data correctly for the PollTime option. - /// - [SkippableTheory] - [InlineData("-1", true)] - [InlineData("zero", true)] - [InlineData("-2147483648", true)] - [InlineData("2147483648", true)] - [InlineData("0")] - [InlineData("1")] - [InlineData("2147483647")] - public void CanSetOptionPollTime(string value, bool throws = false) - { - AdbcStatement statement = NewConnection().CreateStatement(); - if (throws) - { - Assert.Throws(() => statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, value)); - } - else - { - statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, value); - } - } - - /// - /// Validates if the SetOption handle valid/invalid data correctly for the BatchSize option. 
- /// - [SkippableTheory] - [InlineData("-1", true)] - [InlineData("one", true)] - [InlineData("-2147483648", true)] - [InlineData("2147483648", true)] - [InlineData("0", true)] - [InlineData("1")] - [InlineData("2147483647")] - public void CanSetOptionBatchSize(string value, bool throws = false) - { - AdbcStatement statement = NewConnection().CreateStatement(); - if (throws) - { - Assert.Throws(() => statement.SetOption(SparkStatement.Options.BatchSize, value)); - } - else - { - statement.SetOption(SparkStatement.Options.BatchSize, value); - } - } - - /// - /// Validates if the driver can execute update statements. - /// - [SkippableFact, Order(1)] - public async Task CanInteractUsingSetOptions() + public StatementTests(ITestOutputHelper? outputHelper) + : base(outputHelper, new SparkTestEnvironment.Factory()) { - const string columnName = "INDEX"; - Statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, "100"); - Statement.SetOption(SparkStatement.Options.BatchSize, "10"); - using TemporaryTable temporaryTable = await NewTemporaryTableAsync(Statement, $"{columnName} INT"); - await ValidateInsertSelectDeleteSingleValueAsync(temporaryTable.TableName, columnName, 1); } } } diff --git a/csharp/test/Drivers/Apache/Spark/StringValueTests.cs b/csharp/test/Drivers/Apache/Spark/StringValueTests.cs index 9d2f708d0b..ef4e6f1fb4 100644 --- a/csharp/test/Drivers/Apache/Spark/StringValueTests.cs +++ b/csharp/test/Drivers/Apache/Spark/StringValueTests.cs @@ -15,115 +15,52 @@ * limitations under the License. */ -using System; -using System.Collections.Generic; -using System.Globalization; using System.Threading.Tasks; -using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Xunit; using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { - // TODO: When supported, use prepared statements instead of SQL string literals - // Which will better test how the driver handles values sent/received - - /// - /// Validates that specific string and character values can be inserted, retrieved and targeted correctly - /// - public class StringValueTests : SparkTestBase + public class StringValueTests(ITestOutputHelper output) + : Common.StringValueTests(output, new SparkTestEnvironment.Factory()) { - public StringValueTests(ITestOutputHelper output) : base(output) { } - - public static IEnumerable ByteArrayData(int size) + [SkippableTheory] + [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", SparkServerType.Databricks)] + internal async Task TestStringDataDatabricks(string? value, SparkServerType serverType) { - var rnd = new Random(); - byte[] bytes = new byte[size]; - rnd.NextBytes(bytes); - yield return new object[] { bytes }; + Skip.If(TestEnvironment.ServerType != serverType); + await TestStringData(value); } - /// - /// Validates if driver can send and receive specific String values correctly. - /// [SkippableTheory] - [InlineData(null)] - [InlineData("")] - [InlineData("你好")] - [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.")] - [InlineData(" Leading and trailing spaces ")] - public async Task TestStringData(string? value) + [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", SparkServerType.Databricks)] + internal async Task TestVarcharDataDatabricks(string? 
value, SparkServerType serverType) { - string columnName = "STRINGTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "STRING")); - await ValidateInsertSelectDeleteSingleValueAsync( - table.TableName, - columnName, - value, - value != null ? QuoteValue(value) : value); + Skip.If(TestEnvironment.ServerType != serverType); + await TestVarcharData(value); } - /// - /// Validates if driver can send and receive specific VARCHAR values correctly. - /// [SkippableTheory] - [InlineData(null)] - [InlineData("")] - [InlineData("你好")] - [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.")] - [InlineData(" Leading and trailing spaces ")] - public async Task TestVarcharData(string? value) + [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", SparkServerType.Databricks)] + internal async Task TestCharDataDatabricks(string? value, SparkServerType serverType) { - string columnName = "VARCHARTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "VARCHAR(100)")); - await ValidateInsertSelectDeleteSingleValueAsync( - table.TableName, - columnName, - value, - value != null ? QuoteValue(value) : value); + Skip.If(TestEnvironment.ServerType != serverType); + await TestCharData(value); } - /// - /// Validates if driver can send and receive specific VARCHAR values correctly. - /// - [SkippableTheory] - [InlineData(null)] - [InlineData("")] - [InlineData("你好")] - [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.")] - [InlineData(" Leading and trailing spaces ")] - public async Task TestCharData(string? value) + protected override async Task TestVarcharExceptionData(string value, string[] expectedTexts, string? expectedSqlState) { - string columnName = "CHARTYPE"; - int fieldLength = 100; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, $"CHAR({fieldLength})")); - - string? formattedValue = value != null ? QuoteValue(value.PadRight(fieldLength)) : value; - string? paddedValue = value != null ? value.PadRight(fieldLength) : value; - - await InsertSingleValueAsync(table.TableName, columnName, formattedValue); - await SelectAndValidateValuesAsync(table.TableName, columnName, paddedValue, 1, formattedValue); - string whereClause = GetWhereClause(columnName, formattedValue ?? paddedValue); - await DeleteFromTableAsync(table.TableName, whereClause, 1); + Skip.If(TestEnvironment.ServerType == SparkServerType.Databricks); + await base.TestVarcharExceptionData(value, expectedTexts, expectedSqlState); } - /// - /// Validates if driver fails to insert invalid length of VARCHAR value. - /// [SkippableTheory] - [InlineData("String whose length is too long for VARCHAR(10).")] - public async Task TestVarcharExceptionData(string value) + [InlineData("String whose length is too long for VARCHAR(10).", new string[] { "DELTA_EXCEED_CHAR_VARCHAR_LIMIT", "DeltaInvariantViolationException" }, "22001")] + public async Task TestVarcharExceptionDataDatabricks(string value, string[] expectedTexts, string? 
expectedSqlState) { - string columnName = "VARCHARTYPE"; - using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "VARCHAR(10)")); - AdbcException exception = await Assert.ThrowsAsync(async () => await ValidateInsertSelectDeleteSingleValueAsync( - table.TableName, - columnName, - value, - value != null ? QuoteValue(value) : value)); - AssertContainsAll(new[] { "DELTA_EXCEED_CHAR_VARCHAR_LIMIT", "DeltaInvariantViolationException" }, exception.Message); - Assert.Equal("22001", exception.SqlState); + Skip.IfNot(TestEnvironment.ServerType == SparkServerType.Databricks, $"Server type: {TestEnvironment.ServerType}"); + await base.TestVarcharExceptionData(value, expectedTexts, expectedSqlState); } - } } diff --git a/csharp/test/Drivers/BigQuery/Apache.Arrow.Adbc.Tests.Drivers.BigQuery.csproj b/csharp/test/Drivers/BigQuery/Apache.Arrow.Adbc.Tests.Drivers.BigQuery.csproj index ee99a4ac85..04f2c92bc0 100644 --- a/csharp/test/Drivers/BigQuery/Apache.Arrow.Adbc.Tests.Drivers.BigQuery.csproj +++ b/csharp/test/Drivers/BigQuery/Apache.Arrow.Adbc.Tests.Drivers.BigQuery.csproj @@ -4,13 +4,14 @@ net8.0 - - - + + + + all runtime; build; native; contentfiles; analyzers; buildtransitive - + diff --git a/csharp/test/Drivers/BigQuery/BigQueryData.cs b/csharp/test/Drivers/BigQuery/BigQueryData.cs index 6be130ee74..d5959f837a 100644 --- a/csharp/test/Drivers/BigQuery/BigQueryData.cs +++ b/csharp/test/Drivers/BigQuery/BigQueryData.cs @@ -42,9 +42,11 @@ public static SampleDataBuilder GetSampleData() SampleDataBuilder sampleDataBuilder = new SampleDataBuilder(); + // standard values sampleDataBuilder.Samples.Add( - new SampleData() { + new SampleData() + { Query = "SELECT " + "CAST(1 as INT64) as id, " + "CAST(1.23 as FLOAT64) as number, " + @@ -63,25 +65,25 @@ public static SampleDataBuilder GetSampleData() "PARSE_JSON('{\"name\":\"Jane Doe\",\"age\":29}') as json", ExpectedValues = new List() { - new ColumnNetTypeArrowTypeValue("id", typeof(long), typeof(Int64Type), 1L), - new ColumnNetTypeArrowTypeValue("number", typeof(double), typeof(DoubleType), 1.23d), - new ColumnNetTypeArrowTypeValue("decimal", typeof(SqlDecimal), typeof(Decimal128Type), SqlDecimal.Parse("4.56")), - new ColumnNetTypeArrowTypeValue("big_decimal", typeof(string), typeof(StringType), "7.89000000000000000000000000000000000000"), - new ColumnNetTypeArrowTypeValue("is_active", typeof(bool), typeof(BooleanType), true), - new ColumnNetTypeArrowTypeValue("name", typeof(string), typeof(StringType), "John Doe"), - new ColumnNetTypeArrowTypeValue("data", typeof(byte[]), typeof(BinaryType), UTF8Encoding.UTF8.GetBytes("abc123")), - new ColumnNetTypeArrowTypeValue("date", typeof(DateTime), typeof(Date64Type), new DateTime(2023, 9, 8)), -#if NET6_0_OR_GREATER - new ColumnNetTypeArrowTypeValue("time", typeof(TimeOnly), typeof(Time64Type), new TimeOnly(12, 34, 56)), //'12:34:56' -#else - new ColumnNetTypeArrowTypeValue("time", typeof(TimeSpan), typeof(Time64Type), new TimeSpan(12, 34, 56)), -#endif - new ColumnNetTypeArrowTypeValue("datetime", typeof(DateTimeOffset), typeof(TimestampType), new DateTimeOffset(new DateTime(2023, 9, 8, 12, 34, 56), TimeSpan.Zero)), - new ColumnNetTypeArrowTypeValue("timestamp", typeof(DateTimeOffset), typeof(TimestampType), new DateTimeOffset(new DateTime(2023, 9, 8, 12, 34, 56), TimeSpan.Zero)), - new ColumnNetTypeArrowTypeValue("point", typeof(string), typeof(StringType), "POINT(1 2)"), - new ColumnNetTypeArrowTypeValue("numbers", typeof(Int64Array), typeof(ListType), 
numbersArray), - new ColumnNetTypeArrowTypeValue("person", typeof(string), typeof(StringType), "{\"name\":\"John Doe\",\"age\":30}"), - new ColumnNetTypeArrowTypeValue("json", typeof(string), typeof(StringType), "{\"age\":29,\"name\":\"Jane Doe\"}") + new ColumnNetTypeArrowTypeValue("id", typeof(long), typeof(Int64Type), 1L), + new ColumnNetTypeArrowTypeValue("number", typeof(double), typeof(DoubleType), 1.23d), + new ColumnNetTypeArrowTypeValue("decimal", typeof(SqlDecimal), typeof(Decimal128Type), SqlDecimal.Parse("4.56")), + new ColumnNetTypeArrowTypeValue("big_decimal", typeof(string), typeof(StringType), "7.89000000000000000000000000000000000000"), + new ColumnNetTypeArrowTypeValue("is_active", typeof(bool), typeof(BooleanType), true), + new ColumnNetTypeArrowTypeValue("name", typeof(string), typeof(StringType), "John Doe"), + new ColumnNetTypeArrowTypeValue("data", typeof(byte[]), typeof(BinaryType), UTF8Encoding.UTF8.GetBytes("abc123")), + new ColumnNetTypeArrowTypeValue("date", typeof(DateTime), typeof(Date64Type), new DateTime(2023, 9, 8)), + #if NET6_0_OR_GREATER + new ColumnNetTypeArrowTypeValue("time", typeof(TimeOnly), typeof(Time64Type), new TimeOnly(12, 34, 56)), //'12:34:56' + #else + new ColumnNetTypeArrowTypeValue("time", typeof(TimeSpan), typeof(Time64Type), new TimeSpan(12, 34, 56)), + #endif + new ColumnNetTypeArrowTypeValue("datetime", typeof(DateTimeOffset), typeof(TimestampType), new DateTimeOffset(new DateTime(2023, 9, 8, 12, 34, 56), TimeSpan.Zero)), + new ColumnNetTypeArrowTypeValue("timestamp", typeof(DateTimeOffset), typeof(TimestampType), new DateTimeOffset(new DateTime(2023, 9, 8, 12, 34, 56), TimeSpan.Zero)), + new ColumnNetTypeArrowTypeValue("point", typeof(string), typeof(StringType), "POINT(1 2)"), + new ColumnNetTypeArrowTypeValue("numbers", typeof(Int64Array), typeof(ListType), numbersArray), + new ColumnNetTypeArrowTypeValue("person", typeof(string), typeof(StringType), "{\"name\":\"John Doe\",\"age\":30}"), + new ColumnNetTypeArrowTypeValue("json", typeof(string), typeof(StringType), "{\"age\":29,\"name\":\"Jane Doe\"}") } }); @@ -125,82 +127,83 @@ public static SampleDataBuilder GetSampleData() "STRUCT(CAST(NULL as STRING) as name, CAST(NULL as INT64) as age) as person", ExpectedValues = new List() { - new ColumnNetTypeArrowTypeValue("id", typeof(long), typeof(Int64Type), null), - new ColumnNetTypeArrowTypeValue("number", typeof(double), typeof(DoubleType), null), - new ColumnNetTypeArrowTypeValue("decimal", typeof(SqlDecimal), typeof(Decimal128Type), null), - new ColumnNetTypeArrowTypeValue("big_decimal", typeof(string), typeof(StringType), null), - new ColumnNetTypeArrowTypeValue("is_active", typeof(bool), typeof(BooleanType), null), - new ColumnNetTypeArrowTypeValue("name", typeof(string), typeof(StringType), null), - new ColumnNetTypeArrowTypeValue("data", typeof(byte[]), typeof(BinaryType), null), - new ColumnNetTypeArrowTypeValue("date", typeof(DateTime), typeof(Date64Type), null), -#if NET6_0_OR_GREATER - new ColumnNetTypeArrowTypeValue("time", typeof(TimeOnly), typeof(Time64Type), null), -#else - new ColumnNetTypeArrowTypeValue("time", typeof(TimeSpan), typeof(Time64Type), null), -#endif - new ColumnNetTypeArrowTypeValue("datetime", typeof(DateTimeOffset), typeof(TimestampType), null), - new ColumnNetTypeArrowTypeValue("timestamp", typeof(DateTimeOffset), typeof(TimestampType), null), - new ColumnNetTypeArrowTypeValue("point", typeof(string), typeof(StringType), null), - new ColumnNetTypeArrowTypeValue("numbers", typeof(Int64Array), 
typeof(ListType), emptyNumbersArray), - new ColumnNetTypeArrowTypeValue("person", typeof(string), typeof(StringType), "{\"name\":null,\"age\":null}") + new ColumnNetTypeArrowTypeValue("id", typeof(long), typeof(Int64Type), null), + new ColumnNetTypeArrowTypeValue("number", typeof(double), typeof(DoubleType), null), + new ColumnNetTypeArrowTypeValue("decimal", typeof(SqlDecimal), typeof(Decimal128Type), null), + new ColumnNetTypeArrowTypeValue("big_decimal", typeof(string), typeof(StringType), null), + new ColumnNetTypeArrowTypeValue("is_active", typeof(bool), typeof(BooleanType), null), + new ColumnNetTypeArrowTypeValue("name", typeof(string), typeof(StringType), null), + new ColumnNetTypeArrowTypeValue("data", typeof(byte[]), typeof(BinaryType), null), + new ColumnNetTypeArrowTypeValue("date", typeof(DateTime), typeof(Date64Type), null), + #if NET6_0_OR_GREATER + new ColumnNetTypeArrowTypeValue("time", typeof(TimeOnly), typeof(Time64Type), null), + #else + new ColumnNetTypeArrowTypeValue("time", typeof(TimeSpan), typeof(Time64Type), null), + #endif + new ColumnNetTypeArrowTypeValue("datetime", typeof(DateTimeOffset), typeof(TimestampType), null), + new ColumnNetTypeArrowTypeValue("timestamp", typeof(DateTimeOffset), typeof(TimestampType), null), + new ColumnNetTypeArrowTypeValue("point", typeof(string), typeof(StringType), null), + new ColumnNetTypeArrowTypeValue("numbers", typeof(Int64Array), typeof(ListType), emptyNumbersArray), + new ColumnNetTypeArrowTypeValue("person", typeof(string), typeof(StringType), "{\"name\":null,\"age\":null}") } }); + // complex struct sampleDataBuilder.Samples.Add( - new SampleData() - { - Query = "SELECT " + - "STRUCT(" + - "\"Iron Man\" as name," + - "\"Avengers\" as team," + - "[\"Genius\", \"Billionaire\", \"Playboy\", \"Philanthropist\"] as powers," + - "[" + - " STRUCT(" + - " \"Captain America\" as name, " + - " \"Avengers\" as team, " + - " [\"Super Soldier Serum\", \"Vibranium Shield\"] as powers, " + - " [" + - " STRUCT(" + - " \"Thanos\" as name, " + - " \"Black Order\" as team, " + - " [\"Infinity Gauntlet\", \"Super Strength\", \"Teleportation\"] as powers, " + - " [" + - " STRUCT(" + - " \"Loki\" as name, " + - " \"Asgard\" as team, " + - " [\"Magic\", \"Shapeshifting\", \"Trickery\"] as powers " + - " )" + - " ] as allies" + - " )" + - " ] as enemies" + - " )," + - " STRUCT(" + - " \"Spider-Man\" as name, " + - " \"Avengers\" as team, " + - " [\"Spider-Sense\", \"Web-Shooting\", \"Wall-Crawling\"] as powers, " + - " [" + - " STRUCT(" + - " \"Green Goblin\" as name, " + - " \"Sinister Six\" as team, " + - " [\"Glider\", \"Pumpkin Bombs\", \"Super Strength\"] as powers, " + - " [" + - " STRUCT(" + - " \"Doctor Octopus\" as name, " + - " \"Sinister Six\" as team, " + - " [\"Mechanical Arms\", \"Genius\", \"Madness\"] as powers " + - " )" + - " ] as allies" + - " )" + - " ] as enemies" + - " )" + - " ] as friends" + - ") as iron_man", - ExpectedValues = new List() + new SampleData() + { + Query = "SELECT " + + "STRUCT(" + + "\"Iron Man\" as name," + + "\"Avengers\" as team," + + "[\"Genius\", \"Billionaire\", \"Playboy\", \"Philanthropist\"] as powers," + + "[" + + " STRUCT(" + + " \"Captain America\" as name, " + + " \"Avengers\" as team, " + + " [\"Super Soldier Serum\", \"Vibranium Shield\"] as powers, " + + " [" + + " STRUCT(" + + " \"Thanos\" as name, " + + " \"Black Order\" as team, " + + " [\"Infinity Gauntlet\", \"Super Strength\", \"Teleportation\"] as powers, " + + " [" + + " STRUCT(" + + " \"Loki\" as name, " + + " \"Asgard\" as team, 
" + + " [\"Magic\", \"Shapeshifting\", \"Trickery\"] as powers " + + " )" + + " ] as allies" + + " )" + + " ] as enemies" + + " )," + + " STRUCT(" + + " \"Spider-Man\" as name, " + + " \"Avengers\" as team, " + + " [\"Spider-Sense\", \"Web-Shooting\", \"Wall-Crawling\"] as powers, " + + " [" + + " STRUCT(" + + " \"Green Goblin\" as name, " + + " \"Sinister Six\" as team, " + + " [\"Glider\", \"Pumpkin Bombs\", \"Super Strength\"] as powers, " + + " [" + + " STRUCT(" + + " \"Doctor Octopus\" as name, " + + " \"Sinister Six\" as team, " + + " [\"Mechanical Arms\", \"Genius\", \"Madness\"] as powers " + + " )" + + " ] as allies" + + " )" + + " ] as enemies" + + " )" + + " ] as friends" + + ") as iron_man", + ExpectedValues = new List() { - new ColumnNetTypeArrowTypeValue("iron_man", typeof(string), typeof(StringType), "{\"name\":\"Iron Man\",\"team\":\"Avengers\",\"powers\":[\"Genius\",\"Billionaire\",\"Playboy\",\"Philanthropist\"],\"friends\":[{\"name\":\"Captain America\",\"team\":\"Avengers\",\"powers\":[\"Super Soldier Serum\",\"Vibranium Shield\"],\"enemies\":{\"name\":\"Thanos\",\"team\":\"Black Order\",\"powers\":[\"Infinity Gauntlet\",\"Super Strength\",\"Teleportation\"],\"allies\":{\"name\":\"Loki\",\"team\":\"Asgard\",\"powers\":[\"Magic\",\"Shapeshifting\",\"Trickery\"]}}},{\"name\":\"Spider-Man\",\"team\":\"Avengers\",\"powers\":[\"Spider-Sense\",\"Web-Shooting\",\"Wall-Crawling\"],\"enemies\":{\"name\":\"Green Goblin\",\"team\":\"Sinister Six\",\"powers\":[\"Glider\",\"Pumpkin Bombs\",\"Super Strength\"],\"allies\":{\"name\":\"Doctor Octopus\",\"team\":\"Sinister Six\",\"powers\":[\"Mechanical Arms\",\"Genius\",\"Madness\"]}}}]}") + new ColumnNetTypeArrowTypeValue("iron_man", typeof(string), typeof(StringType), "{\"name\":\"Iron Man\",\"team\":\"Avengers\",\"powers\":[\"Genius\",\"Billionaire\",\"Playboy\",\"Philanthropist\"],\"friends\":[{\"name\":\"Captain America\",\"team\":\"Avengers\",\"powers\":[\"Super Soldier Serum\",\"Vibranium Shield\"],\"enemies\":[{\"name\":\"Thanos\",\"team\":\"Black Order\",\"powers\":[\"Infinity Gauntlet\",\"Super Strength\",\"Teleportation\"],\"allies\":[{\"name\":\"Loki\",\"team\":\"Asgard\",\"powers\":[\"Magic\",\"Shapeshifting\",\"Trickery\"]}]}]},{\"name\":\"Spider-Man\",\"team\":\"Avengers\",\"powers\":[\"Spider-Sense\",\"Web-Shooting\",\"Wall-Crawling\"],\"enemies\":[{\"name\":\"Green Goblin\",\"team\":\"Sinister Six\",\"powers\":[\"Glider\",\"Pumpkin Bombs\",\"Super Strength\"],\"allies\":[{\"name\":\"Doctor Octopus\",\"team\":\"Sinister Six\",\"powers\":[\"Mechanical Arms\",\"Genius\",\"Madness\"]}]}]}]}") } - }); + }); return sampleDataBuilder; } diff --git a/csharp/test/Drivers/BigQuery/BigQueryTestConfiguration.cs b/csharp/test/Drivers/BigQuery/BigQueryTestConfiguration.cs index 20390deb3a..ffcae7cc05 100644 --- a/csharp/test/Drivers/BigQuery/BigQueryTestConfiguration.cs +++ b/csharp/test/Drivers/BigQuery/BigQueryTestConfiguration.cs @@ -56,5 +56,11 @@ public BigQueryTestConfiguration() [JsonPropertyName("includeTableConstraints")] public bool IncludeTableConstraints { get; set; } + + [JsonPropertyName("timeoutMinutes")] + public int? TimeoutMinutes { get; set; } + + [JsonPropertyName("maxStreamCount")] + public int? 
MaxStreamCount { get; set; } } } diff --git a/csharp/test/Drivers/BigQuery/BigQueryTestingUtils.cs b/csharp/test/Drivers/BigQuery/BigQueryTestingUtils.cs index 654cdea2e5..2bc6227bab 100644 --- a/csharp/test/Drivers/BigQuery/BigQueryTestingUtils.cs +++ b/csharp/test/Drivers/BigQuery/BigQueryTestingUtils.cs @@ -86,6 +86,16 @@ internal static Dictionary GetBigQueryParameters(BigQueryTestCon parameters.Add(BigQueryParameters.LargeResultsDestinationTable, testConfiguration.LargeResultsDestinationTable); } + if (testConfiguration.TimeoutMinutes.HasValue) + { + parameters.Add(BigQueryParameters.GetQueryResultsOptionsTimeoutMinutes, testConfiguration.TimeoutMinutes.Value.ToString()); + } + + if (testConfiguration.MaxStreamCount.HasValue) + { + parameters.Add(BigQueryParameters.MaxFetchConcurrency, testConfiguration.MaxStreamCount.Value.ToString()); + } + return parameters; } diff --git a/csharp/test/Drivers/BigQuery/DriverTests.cs b/csharp/test/Drivers/BigQuery/DriverTests.cs index 047c0e0fcb..93ecdc786d 100644 --- a/csharp/test/Drivers/BigQuery/DriverTests.cs +++ b/csharp/test/Drivers/BigQuery/DriverTests.cs @@ -120,7 +120,7 @@ public void CanGetObjects() RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, catalogName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); List? columns = catalogs .Select(s => s.DbSchemas) diff --git a/csharp/test/Drivers/BigQuery/Resources/bigqueryconfig.json b/csharp/test/Drivers/BigQuery/Resources/bigqueryconfig.json index 888ae7eb30..960bcb8683 100644 --- a/csharp/test/Drivers/BigQuery/Resources/bigqueryconfig.json +++ b/csharp/test/Drivers/BigQuery/Resources/bigqueryconfig.json @@ -3,11 +3,12 @@ "clientId": "", "clientSecret": "", "refreshToken": "", + "maxStreamCount": 1, "metadata": { "catalog": "", "schema": "", "table": "", - "expectedColumnCount": 0 + "expectedColumnCount": 0 }, "query": "", "expectedResults": 0 diff --git a/csharp/test/Drivers/BigQuery/readme.md b/csharp/test/Drivers/BigQuery/readme.md index 0770bb6d39..26f177da4a 100644 --- a/csharp/test/Drivers/BigQuery/readme.md +++ b/csharp/test/Drivers/BigQuery/readme.md @@ -38,6 +38,11 @@ The following values can be setup in the configuration - **expectedColumnCount** - Used by metadata tests to validate the number of columns that are returned. - **query** - The query to use. - **expectedResults** - The expected number of results from the query. +- **timeoutMinutes** - The timeout (in minutes). +- **maxStreamCount** - The max stream count. +- **includeTableConstraints** - Whether to include table constraints in the GetObjects query. +- **largeResultsDestinationTable** - Sets the [DestinationTable](https://cloud.google.com/dotnet/docs/reference/Google.Cloud.BigQuery.V2/latest/Google.Cloud.BigQuery.V2.QueryOptions#Google_Cloud_BigQuery_V2_QueryOptions_DestinationTable) value of the QueryOptions if configured. Expects the format to be `{projectId}.{datasetId}.{tableId}` to set the corresponding values in the [TableReference](https://github.com/googleapis/google-api-dotnet-client/blob/6c415c73788b848711e47c6dd33c2f93c76faf97/Src/Generated/Google.Apis.Bigquery.v2/Google.Apis.Bigquery.v2.cs#L9348) class. +- **allowLargeResults** - Whether to allow large results . ## Data This project contains a SQL script to generate BigQuery data in the `resources/BigQueryData.sql` file. This can be used to populate a table in your BigQuery instance with data. 
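For the largeResultsDestinationTable entry documented in the readme change above, the expected `{projectId}.{datasetId}.{tableId}` format can be illustrated with a short sketch (the table name and the splitting code are illustrative only, not the driver's actual parsing logic):

using System;

// Illustrative only: the documented "{projectId}.{datasetId}.{tableId}" format
// split into the three values that feed the corresponding TableReference fields.
string largeResultsDestinationTable = "my-project.my_dataset.my_results_table"; // hypothetical value
string[] parts = largeResultsDestinationTable.Split('.');
if (parts.Length != 3)
    throw new ArgumentException("Expected the format {projectId}.{datasetId}.{tableId}");
string projectId = parts[0];
string datasetId = parts[1];
string tableId = parts[2];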
diff --git a/csharp/test/Drivers/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj b/csharp/test/Drivers/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj index b6e00e11a0..265b67827a 100644 --- a/csharp/test/Drivers/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj +++ b/csharp/test/Drivers/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj @@ -5,13 +5,13 @@ False - - - + + + all runtime; build; native; contentfiles; analyzers; buildtransitive - + diff --git a/csharp/test/Drivers/Interop/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql.csproj b/csharp/test/Drivers/Interop/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql.csproj new file mode 100644 index 0000000000..848ca9ea64 --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql.csproj @@ -0,0 +1,20 @@ + + + net8.0;net472 + net8.0 + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + diff --git a/csharp/test/Drivers/Interop/FlightSql/ClientTests.cs b/csharp/test/Drivers/Interop/FlightSql/ClientTests.cs new file mode 100644 index 0000000000..7dc8c44097 --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/ClientTests.cs @@ -0,0 +1,241 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Data.SqlTypes; +using Apache.Arrow.Adbc.Client; +using Apache.Arrow.Adbc.Tests.Xunit; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql +{ + /// + /// Class for testing the ADBC Client using the Flight SQL ADBC driver. + /// + /// + /// Tests are ordered to ensure data is created + /// for the other queries to run. + /// + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] + public class ClientTests + { + readonly FlightSqlTestConfiguration _testConfiguration; + readonly List _environments; + readonly Dictionary _configuredDrivers = new Dictionary(); + readonly ITestOutputHelper _outputHelper; + + public ClientTests(ITestOutputHelper outputHelper) + { + Skip.IfNot(Utils.CanExecuteTestConfig(FlightSqlTestingUtils.FLIGHTSQL_INTEROP_TEST_CONFIG_VARIABLE)); + _testConfiguration = FlightSqlTestingUtils.LoadFlightSqlTestConfiguration(FlightSqlTestingUtils.FLIGHTSQL_INTEROP_TEST_CONFIG_VARIABLE); + _environments = FlightSqlTestingUtils.GetTestEnvironments(_testConfiguration); + _outputHelper = outputHelper; + } + + /// + /// Validates if the client execute updates. 
+ /// + [SkippableFact, Order(1)] + public void CanClientExecuteUpdate() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + if (environment.SupportsWriteUpdate) + { + using (Adbc.Client.AdbcConnection adbcConnection = GetFlightSqlAdbcConnectionUsingConnectionString(environment, _testConfiguration)) + { + string[] queries = FlightSqlTestingUtils.GetQueries(environment); + + List expectedResults = new List() { -1, 1, 1 }; + + Tests.ClientTests.CanClientExecuteUpdate(adbcConnection, environment, queries, expectedResults); + } + } + else + { + _outputHelper.WriteLine("WriteUpdate is not supported in the [" + environment.Name + "] environment"); + } + } + } + + /// + /// Validates if the client execute updates using the reader. + /// + [SkippableFact, Order(2)] + public void CanClientExecuteUpdateUsingExecuteReader() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + if (environment.SupportsWriteUpdate) + { + using (Adbc.Client.AdbcConnection adbcConnection = GetFlightSqlAdbcConnectionUsingConnectionString(environment, _testConfiguration)) + { + adbcConnection.Open(); + + string[] queries = FlightSqlTestingUtils.GetQueries(environment); + + List expectedResults = new List() { $"Table {environment.Metadata.Table} successfully created.", new SqlDecimal(1L), new SqlDecimal(1L) }; + + for (int i = 0; i < queries.Length; i++) + { + string query = queries[i]; + AdbcCommand adbcCommand = adbcConnection.CreateCommand(); + adbcCommand.CommandText = query; + + AdbcDataReader reader = adbcCommand.ExecuteReader(CommandBehavior.Default); + + if (reader.Read()) + { + Assert.True(expectedResults[i].Equals(reader.GetValue(0)), $"The expected affected rows do not match the actual affected rows at position {i} in the [" + environment.Name + "] environment"); + } + else + { + Assert.Fail("Could not read the records in the [" + environment.Name + "] environment"); + } + } + } + } + else + { + _outputHelper.WriteLine("WriteUpdate is not supported in the [" + environment.Name + "] environment"); + } + } + } + + /// + /// Validates if the client can get the schema. + /// + [SkippableFact, Order(3)] + public void CanClientGetSchema() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + using (Adbc.Client.AdbcConnection adbcConnection = GetFlightSqlAdbcConnectionUsingConnectionString(environment, _testConfiguration)) + { + Tests.ClientTests.CanClientGetSchema(adbcConnection, environment, environmentName: environment.Name); + } + } + } + + /// + /// Validates if the client can connect to a live server + /// and parse the results. + /// + [SkippableFact, Order(4)] + public void CanClientExecuteQuery() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + using (Adbc.Client.AdbcConnection adbcConnection = GetFlightSqlAdbcConnectionUsingConnectionString(environment, _testConfiguration)) + { + Tests.ClientTests.CanClientExecuteQuery(adbcConnection, environment, additionalCommandOptionsSetter: null, environmentName: environment.Name); + } + } + } + + // + /// Validates if the client can connect to a live server + /// and parse the results. 
+ /// + [SkippableFact, Order(4)] + public void CanClientExecuteQueryWithNoResults() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + environment.Query = "SELECT * WHERE 0=1"; + environment.ExpectedResultsCount = 0; + + using (Adbc.Client.AdbcConnection adbcConnection = GetFlightSqlAdbcConnectionUsingConnectionString(environment, _testConfiguration)) + { + Tests.ClientTests.CanClientExecuteQuery(adbcConnection, environment, additionalCommandOptionsSetter: null, environment.Name); + } + } + } + + /// + /// Validates if the client is retrieving and converting values + /// to the expected types. + /// + [SkippableFact, Order(6)] + public void VerifyTypesAndValues() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + using (Adbc.Client.AdbcConnection adbcConnection = GetFlightSqlAdbcConnectionUsingConnectionString(environment, _testConfiguration)) + { + SampleDataBuilder sampleDataBuilder = FlightSqlData.GetSampleData(environment.EnvironmentType); + + Tests.ClientTests.VerifyTypesAndValues(adbcConnection, sampleDataBuilder, environment.Name); + } + } + } + + private Adbc.Client.AdbcConnection GetFlightSqlAdbcConnectionUsingConnectionString(FlightSqlTestEnvironment environment, FlightSqlTestConfiguration testConfiguration, string? authType = null) + { + // see https://arrow.apache.org/adbc/main/driver/flight_sql.html + DbConnectionStringBuilder builder = new DbConnectionStringBuilder(true); + if (!string.IsNullOrEmpty(environment.Uri)) + { + builder[FlightSqlParameters.Uri] = environment.Uri; + } + + foreach (string key in environment.RPCCallHeaders.Keys) + { + builder[FlightSqlParameters.OptionRPCCallHeaderPrefix + key] = environment.RPCCallHeaders[key]; + } + + if (!string.IsNullOrEmpty(environment.AuthorizationHeader)) + { + builder[FlightSqlParameters.OptionAuthorizationHeader] = environment.AuthorizationHeader; + } + else + { + if (!string.IsNullOrEmpty(environment.Username) && !string.IsNullOrEmpty(environment.Password)) + { + builder[FlightSqlParameters.Username] = environment.Username; + builder[FlightSqlParameters.Password] = environment.Password; + } + } + + if (!string.IsNullOrEmpty(environment.TimeoutQuery)) + builder[FlightSqlParameters.OptionTimeoutQuery] = environment.TimeoutQuery; + + if (!string.IsNullOrEmpty(environment.TimeoutFetch)) + builder[FlightSqlParameters.OptionTimeoutFetch] = environment.TimeoutFetch; + + if (!string.IsNullOrEmpty(environment.TimeoutUpdate)) + builder[FlightSqlParameters.OptionTimeoutUpdate] = environment.TimeoutUpdate; + + if (environment.SSLSkipVerify) + builder[FlightSqlParameters.OptionSSLSkipVerify] = Convert.ToString(environment.SSLSkipVerify).ToLowerInvariant(); + + if (!string.IsNullOrEmpty(environment.Authority)) + builder[FlightSqlParameters.OptionAuthority] = environment.Authority; + + AdbcDriver driver = FlightSqlTestingUtils.GetFlightSqlAdbcDriver(testConfiguration); + + return new Adbc.Client.AdbcConnection(builder.ConnectionString) + { + AdbcDriver = driver + }; + } + } +} diff --git a/csharp/test/Drivers/Interop/FlightSql/DriverTests.cs b/csharp/test/Drivers/Interop/FlightSql/DriverTests.cs new file mode 100644 index 0000000000..bc3ea53ef8 --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/DriverTests.cs @@ -0,0 +1,456 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Apache.Arrow.Adbc.Tests.Metadata; +using Apache.Arrow.Adbc.Tests.Xunit; +using Apache.Arrow.Ipc; +using Apache.Arrow.Types; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql +{ + /// + /// Class for testing the Flight SQL ADBC driver connection tests. + /// + /// + /// Tests are ordered to ensure data is created for the other + /// queries to run. + /// + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] + public class DriverTests : IDisposable + { + readonly FlightSqlTestConfiguration _testConfiguration; + readonly List _environments; + readonly Dictionary _configuredConnections = new Dictionary(); + readonly ITestOutputHelper _outputHelper; + + private List GetPatterns(string? namePattern, bool caseSenstive) + { + List patterns = new List(); + + string name = namePattern!; + patterns.Add(name); + patterns.Add($"{GetPartialNameForPatternMatch(name)}%"); + patterns.Add($"_{GetNameWithoutFirstChatacter(name)}"); + + if (!caseSenstive) + { + patterns.Add($"{GetPartialNameForPatternMatch(name).ToLower()}%"); + patterns.Add($"{GetPartialNameForPatternMatch(name).ToUpper()}%"); + patterns.Add($"_{GetNameWithoutFirstChatacter(name).ToLower()}"); + patterns.Add($"_{GetNameWithoutFirstChatacter(name).ToUpper()}"); + } + + return patterns; + } + + public DriverTests(ITestOutputHelper outputHelper) + { + Skip.IfNot(Utils.CanExecuteTestConfig(FlightSqlTestingUtils.FLIGHTSQL_INTEROP_TEST_CONFIG_VARIABLE)); + _testConfiguration = FlightSqlTestingUtils.LoadFlightSqlTestConfiguration(); + _environments = FlightSqlTestingUtils.GetTestEnvironments(_testConfiguration); + _outputHelper = outputHelper; + + foreach (FlightSqlTestEnvironment environment in _environments) + { + Dictionary parameters = new Dictionary(); + Dictionary options = new Dictionary(); + AdbcDriver driver = FlightSqlTestingUtils.GetAdbcDriver(_testConfiguration, environment, out parameters); + AdbcDatabase database = driver.Open(parameters); + AdbcConnection connection = database.Connect(options); + + _configuredConnections.Add(environment.Name!, connection); + } + } + + /// + /// Validates if the driver can connect to a live server and + /// parse the results. + /// + /// + /// Tests are ordered to ensure data is created + /// for the other queries to run. + /// + [SkippableFact, Order(1)] + public void CanExecuteUpdate() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + // Dremio doesn't have acceptPut implemented by design. 
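+                // Write-update coverage is therefore gated on each environment's supportsWriteUpdate flag in the test configuration.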
+ if (environment.SupportsWriteUpdate) + { + string[] queries = FlightSqlTestingUtils.GetQueries(environment); + + List expectedResults = new List() { -1, 1, 1 }; + + for (int i = 0; i < queries.Length; i++) + { + string query = queries[i]; + UpdateResult updateResult = ExecuteUpdateStatement(environment, query); + Assert.Equal(expectedResults[i], updateResult.AffectedRows); + } + } + else + { + _outputHelper.WriteLine("WriteUpdate is not supported in the [" + environment.Name + "] environment"); + } + } + } + + /// + /// Validates if the driver can call GetInfo. + /// + [SkippableFact, Order(2)] + public void CanGetInfo() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + using IArrowArrayStream stream = GetAdbcConnection(environment.Name).GetInfo(new List() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName }); + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name"); + + List expectedValues = new List() { "DriverName", "DriverVersion", "VendorName" }; + + for (int i = 0; i < infoNameArray.Length; i++) + { + uint? uintValue = infoNameArray.GetValue(i); + + if (uintValue.HasValue) + { + AdbcInfoCode value = (AdbcInfoCode)uintValue; + DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); + + Assert.Contains(value.ToString(), expectedValues); + + StringArray stringArray = (StringArray)valueArray.Fields[0]; + _outputHelper.WriteLine($"{value}={stringArray.GetString(i)}"); + } + } + } + } + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as Catalogs and CatalogName passed as a pattern. + /// + [SkippableFact, Order(3)] + public void CanGetObjectsCatalogs() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + // Dremio doesn't use catalogs + if (environment.SupportsCatalogs) + { + string databaseName = environment.Metadata.Catalog; + foreach (string catalogPattern in GetPatterns(databaseName, environment.CaseSensitive)) + { + string schemaName = environment.Metadata.Schema; + + using IArrowArrayStream stream = GetAdbcConnection(environment.Name).GetObjects( + depth: AdbcConnection.GetObjectsDepth.Catalogs, + catalogPattern: catalogPattern, + dbSchemaPattern: null, + tableNamePattern: null, + tableTypes: environment.TableTypes, + columnNamePattern: null); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, null); + AdbcCatalog? catalog = catalogs.Where((catalog) => string.Equals(catalog.Name, databaseName)).FirstOrDefault(); + + Assert.True(catalog != null, "catalog should not be null in the [" + environment.Name + "] environment"); + } + } + else + { + _outputHelper.WriteLine("Catalogs are not supported in the [" + environment.Name + "] environment"); + } + } + } + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as DbSchemas with DbSchemaName as a pattern. 
+ /// + [SkippableFact, Order(3)] + public void CanGetObjectsDbSchemas() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + string databaseName = environment.Metadata.Catalog; + string schemaName = environment.Metadata.Schema; + + if (schemaName != null) + { + foreach (string dbSchemaPattern in GetPatterns(schemaName, environment.CaseSensitive)) + { + using IArrowArrayStream stream = GetAdbcConnection(environment.Name).GetObjects( + depth: AdbcConnection.GetObjectsDepth.DbSchemas, + catalogPattern: databaseName, + dbSchemaPattern: dbSchemaPattern, + tableNamePattern: null, + tableTypes: environment.TableTypes, + columnNamePattern: null); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); + + List? dbSchemas = catalogs + .Where(c => environment.SupportsCatalogs ? string.Equals(c.Name, databaseName) : true) + .Select(c => c.DbSchemas) + .FirstOrDefault(); + + Assert.True(dbSchemas != null, "dbSchemas should not be null in the [" + environment.Name + "] environment"); + + AdbcDbSchema? dbSchema = dbSchemas + .Where(dbSchema => schemaName == null ? string.Equals(dbSchema.Name, string.Empty) : string.Equals(dbSchema.Name, schemaName)) + .FirstOrDefault(); + + Assert.True(dbSchema != null, "dbSchema should not be null in the [" + environment.Name + "] environment"); + } + } + } + } + + /// + /// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a pattern. + /// + [SkippableFact, Order(3)] + public void CanGetObjectsTables() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + string databaseName = environment.Metadata.Catalog; + string schemaName = environment.Metadata.Schema; + string tableName = environment.Metadata.Table; + + foreach (string tableNamePattern in GetPatterns(tableName, environment.CaseSensitive)) + { + using IArrowArrayStream stream = GetAdbcConnection(environment.Name).GetObjects( + depth: AdbcConnection.GetObjectsDepth.Tables, + catalogPattern: databaseName, + dbSchemaPattern: schemaName, + tableNamePattern: tableNamePattern, + tableTypes: environment.TableTypes, + columnNamePattern: null); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); + + List? schemas = catalogs + .Where(c => environment.SupportsCatalogs ? string.Equals(c.Name, databaseName) : true) + .Select(c => c.DbSchemas) + .FirstOrDefault(); + + Assert.True(schemas != null, "schemas should not be null in the [" + environment.Name + "] environment"); + + List? tables = schemas + .Where(s => schemaName == null ? string.Equals(s.Name, string.Empty) : string.Equals(s.Name, schemaName)) + .Select(s => s.Tables) + .FirstOrDefault(); + + Assert.True(tables != null, "schemas should not be null in the [" + environment.Name + "] environment"); + + AdbcTable? table = tables.Where((table) => string.Equals(table.Name, tableName)).FirstOrDefault(); + Assert.True(table != null, $"could not find the table named [{tableName}] from the [{tableNamePattern}] pattern in the [" + environment.Name + "] environment. Is this environment case sensitive?"); + } + } + } + + /// + /// Validates if the driver can call GetObjects for GetObjectsDepth as All. + /// + [SkippableFact, Order(3)] + public void CanGetObjectsAll() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + string? 
databaseName = environment.Metadata.Catalog; + string? schemaName = environment.Metadata.Schema; + string tableName = environment.Metadata.Table; + string? columnName = null; + + using IArrowArrayStream stream = GetAdbcConnection(environment.Name).GetObjects( + depth: AdbcConnection.GetObjectsDepth.All, + catalogPattern: databaseName, + dbSchemaPattern: schemaName, + tableNamePattern: tableName, + tableTypes: environment.TableTypes, + columnNamePattern: columnName); + + using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); + + List? schemas = catalogs + .Where(c => environment.SupportsCatalogs ? string.Equals(c.Name, databaseName) : true) + .Select(c => c.DbSchemas) + .FirstOrDefault(); + + Assert.True(schemas != null, "schemas should not be null in the [" + environment.Name + "] environment"); + + List? tables; + + if (schemaName == null) + { + tables = schemas + .Where(s => string.Equals(s.Name, string.Empty)) + .Select(s => s.Tables) + .FirstOrDefault(); + } + else + { + tables = schemas + .Where(s => string.Equals(s.Name, schemaName)) + .Select(s => s.Tables) + .FirstOrDefault(); + } + + Assert.True(tables != null, "tables should not be null in the [" + environment.Name + "] environment"); + + AdbcTable? table = tables + .Where(t => string.Equals(t.Name, tableName)) + .FirstOrDefault(); + + Assert.True(table != null, "table should not be null in the [" + environment.Name + "] environment"); + + List? columns = table.Columns; + + Assert.True(columns != null, "Columns cannot be null in the [" + environment.Name + "] environment"); + Assert.Equal(environment.Metadata.ExpectedColumnCount, columns.Count); + } + } + + /// + /// Validates if the driver can call GetTableSchema. + /// + [SkippableFact, Order(4)] + public void CanGetTableSchema() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + string databaseName = environment.Metadata.Catalog; + string schemaName = environment.Metadata.Schema; + string tableName = environment.Metadata.Table; + + Schema schema = GetAdbcConnection(environment.Name).GetTableSchema(databaseName, schemaName, tableName); + + int numberOfFields = schema.FieldsList.Count; + + Assert.True(environment.Metadata.ExpectedColumnCount == numberOfFields, "ExpectedColumnCount not equal in the [" + environment.Name + "] environment"); + } + } + + /// + /// Validates if the driver can call GetTableTypes. + /// + [SkippableFact, Order(5)] + public void CanGetTableTypes() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + using IArrowArrayStream arrowArrayStream = GetAdbcConnection(environment.Name).GetTableTypes(); + + using RecordBatch recordBatch = arrowArrayStream.ReadNextRecordBatchAsync().Result; + + StringArray stringArray = (StringArray)recordBatch.Column("table_type"); + + List known_types = environment.TableTypes; + + int results = 0; + + for (int i = 0; i < stringArray.Length; i++) + { + string value = stringArray.GetString(i); + + if (known_types.Contains(value)) + { + results++; + } + } + + Assert.True(known_types.Count == results, "TableTypes not equal in the [" + environment.Name + "] environment"); + } + } + + /// + /// Validates if the driver can connect to a live server and + /// parse the results. 
+ /// + [SkippableFact, Order(6)] + public void CanExecuteQuery() + { + foreach (FlightSqlTestEnvironment environment in _environments) + { + using AdbcStatement statement = GetAdbcConnection(environment.Name).CreateStatement(); + statement.SqlQuery = environment.Query; + + QueryResult queryResult = statement.ExecuteQuery(); + + Tests.DriverTests.CanExecuteQuery(queryResult, environment.ExpectedResultsCount, environment.Name); + } + } + + private UpdateResult ExecuteUpdateStatement(FlightSqlTestEnvironment environment, string query) + { + using AdbcStatement statement = GetAdbcConnection(environment.Name).CreateStatement(); + statement.SqlQuery = query; + UpdateResult updateResult = statement.ExecuteUpdate(); + return updateResult; + } + + private AdbcConnection GetAdbcConnection(string? environmentName) + { + if (string.IsNullOrEmpty(environmentName)) + { + throw new ArgumentNullException(nameof(environmentName)); + } + + return _configuredConnections[environmentName!]; + } + + private string GetPartialNameForPatternMatch(string name) + { + if (string.IsNullOrEmpty(name) || name.Length == 1) return name; + + return name.Substring(0, name.Length / 2); + } + + private string GetNameWithoutFirstChatacter(string name) + { + if (string.IsNullOrEmpty(name)) return name; + + return name.Substring(1); + } + + public void Dispose() + { + foreach (AdbcConnection configuredConnection in this._configuredConnections.Values) + { + configuredConnection.Dispose(); + } + + _configuredConnections.Clear(); + } + } +} diff --git a/csharp/test/Drivers/Interop/FlightSql/FlightSqlData.cs b/csharp/test/Drivers/Interop/FlightSql/FlightSqlData.cs new file mode 100644 index 0000000000..07b0e24f34 --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/FlightSqlData.cs @@ -0,0 +1,208 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Text; +using Apache.Arrow.Scalars; +using Apache.Arrow.Types; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql +{ + /// + /// Gets the sample data from Flight SQL. + /// + internal class FlightSqlData + { + /// + /// Sample data + /// + /// + /// The type of environment to get the sample data for. 
+ /// + public static SampleDataBuilder GetSampleData( + FlightSqlTestEnvironmentType environmentType + ) + { + switch (environmentType) + { + case FlightSqlTestEnvironmentType.Denodo: + return GetDenodoSampleData(); + case FlightSqlTestEnvironmentType.Dremio: + return GetDremioSampleData(); + case FlightSqlTestEnvironmentType.DuckDB: + return GetDuckDbSampleData(); + case FlightSqlTestEnvironmentType.SQLite: + return GetSQLiteSampleData(); + default: + throw new InvalidOperationException("Unknown environment type."); + } + } + + private static SampleDataBuilder GetDenodoSampleData() + { + return new SampleDataBuilder(); + } + + private static SampleDataBuilder GetDremioSampleData() + { + ListArray.Builder labuilder = new ListArray.Builder(Int32Type.Default); + Int32Array.Builder numbersBuilder = (Int32Array.Builder)labuilder.ValueBuilder; + labuilder.Append(); + numbersBuilder.AppendRange(new List() { 1, 2, 3 }); + + Int32Array numbersArray = numbersBuilder.Build(); + + SampleDataBuilder sampleDataBuilder = new SampleDataBuilder(); + + sampleDataBuilder.Samples.Add( + new SampleData() + { + Query = "SELECT " + + "TRUE AS sample_boolean, " + + "CAST(123 AS INTEGER) AS sample_integer, " + + "CAST(1234567890 AS BIGINT) AS sample_bigint, " + + "CAST(123.45 AS FLOAT) AS sample_float," + + "CAST(12345.6789 AS DOUBLE) AS sample_double, " + + "CAST(12345.67 AS DECIMAL(10, 2)) AS sample_decimal, " + + "'Sample Text' AS sample_varchar, " + + "DATE '2024-01-01' AS sample_date, " + + "TIME '12:34:56' AS sample_time, " + + "TIMESTAMP '2024-01-01 12:34:56' AS sample_timestamp, " + + "ARRAY[1, 2, 3] AS sample_array, " + + "CONVERT_FROM('{\"name\":\"Gnarly\", \"age\":7, \"car\":null}', 'json') as sample_struct", + ExpectedValues = new List() + { + new ColumnNetTypeArrowTypeValue("sample_boolean", typeof(bool), typeof(BooleanType), true), + new ColumnNetTypeArrowTypeValue("sample_integer", typeof(int), typeof(Int32Type), 123), + new ColumnNetTypeArrowTypeValue("sample_bigint", typeof(Int64), typeof(Int64Type), 1234567890L), + new ColumnNetTypeArrowTypeValue("sample_float", typeof(float), typeof(FloatType), 123.45f), + new ColumnNetTypeArrowTypeValue("sample_double", typeof(double), typeof(DoubleType), 12345.6789d), + new ColumnNetTypeArrowTypeValue("sample_decimal", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(12345.67m)), + new ColumnNetTypeArrowTypeValue("sample_varchar", typeof(string), typeof(StringType), "Sample Text"), + new ColumnNetTypeArrowTypeValue("sample_date", typeof(DateTime), typeof(Date64Type), new DateTime(2024, 1, 1)), +#if NET6_0_OR_GREATER + new ColumnNetTypeArrowTypeValue("sample_time", typeof(TimeOnly), typeof(Time32Type), new TimeOnly(12, 34, 56)), +#else + new ColumnNetTypeArrowTypeValue("sample_time", typeof(TimeSpan), typeof(Time32Type), new TimeSpan(12, 34, 56)), +#endif + new ColumnNetTypeArrowTypeValue("sample_timestamp", typeof(DateTimeOffset), typeof(TimestampType), new DateTimeOffset(new DateTime(2024, 1, 1, 12, 34, 56), TimeSpan.Zero)), + new ColumnNetTypeArrowTypeValue("sample_array", typeof(Int32Array), typeof(ListType), numbersArray), + new ColumnNetTypeArrowTypeValue("sample_struct", typeof(string), typeof(StructType), "{\"name\":\"Gnarly\",\"age\":7}"), + } + }); + + return sampleDataBuilder; + } + + private static SampleDataBuilder GetDuckDbSampleData() + { + SampleDataBuilder sampleDataBuilder = new SampleDataBuilder(); + + sampleDataBuilder.Samples.Add( + new SampleData() + { + Query = "SELECT " + + "42 AS \"TinyInt\", " + + "12345 AS \"SmallInt\", " 
+ + "987654321 AS \"Integer\", " + + "1234567890123 AS \"BigInt\", " + + "3.141592 AS \"Real\", " + + "123.456789123456 AS \"Double\", " + + "DECIMAL '12345.67' AS \"Decimal\", " + + "'DuckDB' AS \"Varchar\", " + + "BLOB 'abc' AS \"Blob\", " + + "TRUE AS \"Boolean\"," + + "DATE '2024-09-10' AS \"Date\", " + + "TIME '12:34:56' AS \"Time\", " + + "TIMESTAMP '2024-09-10 12:34:56' AS \"Timestamp\", " + + "INTERVAL '1 year' AS \"Interval\", " + + "'[1, 2, 3]'::JSON AS \"JSON\", " + + "'[{\"key\": \"value\"}]'::JSON AS \"JSON_Array\", " + + "to_json([true, false, null]) AS \"List_JSON\", " + // need to convert List values to json + "to_json(MAP {'key': 'value'}) AS \"Map_JSON\" ", // need to convert Map values to json + ExpectedValues = new List() + { + new ColumnNetTypeArrowTypeValue("TinyInt", typeof(int), typeof(Int32Type), 42), + new ColumnNetTypeArrowTypeValue("SmallInt", typeof(int), typeof(Int32Type), 12345), + new ColumnNetTypeArrowTypeValue("Integer", typeof(int), typeof(Int32Type), 987654321), + new ColumnNetTypeArrowTypeValue("BigInt", typeof(Int64), typeof(Int64Type), 1234567890123), + new ColumnNetTypeArrowTypeValue("Real", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(3.141592m)), + new ColumnNetTypeArrowTypeValue("Double", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(123.456789123456m)), + new ColumnNetTypeArrowTypeValue("Decimal", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(12345.67m)), + new ColumnNetTypeArrowTypeValue("Varchar", typeof(string), typeof(StringType), "DuckDB"), + new ColumnNetTypeArrowTypeValue("Blob", typeof(byte[]), typeof(BinaryType), Encoding.UTF8.GetBytes("abc")), + new ColumnNetTypeArrowTypeValue("Boolean", typeof(bool), typeof(BooleanType), true), + new ColumnNetTypeArrowTypeValue("Date", typeof(DateTime), typeof(Date32Type), new DateTime(2024, 09, 10)), +#if NET6_0_OR_GREATER + new ColumnNetTypeArrowTypeValue("Time", typeof(TimeOnly), typeof(Time64Type), new TimeOnly(12, 34, 56)), +#else + new ColumnNetTypeArrowTypeValue("Time", typeof(TimeSpan), typeof(Time64Type), new TimeSpan(12, 34, 56)), +#endif + new ColumnNetTypeArrowTypeValue("Timestamp", typeof(DateTimeOffset), typeof(TimestampType), new DateTimeOffset(new DateTime(2024, 9, 10, 12, 34, 56), TimeSpan.Zero)), + new ColumnNetTypeArrowTypeValue("Interval", typeof(MonthDayNanosecondInterval), typeof(IntervalType), new MonthDayNanosecondInterval(12, 0, 0)), + new ColumnNetTypeArrowTypeValue("JSON", typeof(string), typeof(StringType), "[1, 2, 3]"), + new ColumnNetTypeArrowTypeValue("JSON_Array", typeof(string), typeof(StringType), "[{\"key\": \"value\"}]"), + new ColumnNetTypeArrowTypeValue("List_JSON", typeof(string), typeof(StringType),"[true,false,null]"), + new ColumnNetTypeArrowTypeValue("Map_JSON", typeof(string), typeof(StringType), "{\"key\":\"value\"}"), + } + } + ); + + return sampleDataBuilder; + } + + private static SampleDataBuilder GetSQLiteSampleData() + { + string tempTable = Guid.NewGuid().ToString().Replace("-", ""); + + SampleDataBuilder sampleDataBuilder = new SampleDataBuilder(); + + sampleDataBuilder.Samples.Add( + new SampleData() + { + // for SQLite, we can't just select data without a + // table because we get mixed schemas that are returned, + // resulting in an error. so create a temp table, + // insert data, select data, then remove the table. 
+ PreQueryCommands = new List() + { + $"CREATE TEMP TABLE [{tempTable}] (INTEGER_COLUMN INTEGER, TEXT_COLUMN TEXT, BLOB_COLUMN BLOB, REAL_COLUMN REAL, NULL_COLUMN NULL);", + $"INSERT INTO [{tempTable}] (INTEGER_COLUMN, TEXT_COLUMN, BLOB_COLUMN, REAL_COLUMN, NULL_COLUMN) VALUES (42, 'Hello, SQLite', X'426C6F62', 3.14159, NULL);" + }, + Query = $"SELECT INTEGER_COLUMN, TEXT_COLUMN, BLOB_COLUMN, REAL_COLUMN, NULL_COLUMN FROM [{tempTable}];", + PostQueryCommands = new List() + { + $"DROP TABLE [{tempTable}];" + }, + ExpectedValues = new List() + { + new ColumnNetTypeArrowTypeValue("INTEGER_COLUMN", typeof(long), typeof(Int64Type), 42L), + new ColumnNetTypeArrowTypeValue("TEXT_COLUMN", typeof(string), typeof(StringType), "Hello, SQLite"), + new ColumnNetTypeArrowTypeValue("BLOB_COLUMN", typeof(byte[]), typeof(BinaryType), Encoding.UTF8.GetBytes("Blob")), + new ColumnNetTypeArrowTypeValue("REAL_COLUMN", typeof(double), typeof(DoubleType), 3.14159d), + new ColumnNetTypeArrowTypeValue("NULL_COLUMN", typeof(UnionType), typeof(UnionType), null), + } + } + ); + + return sampleDataBuilder; + } + } +} diff --git a/csharp/test/Drivers/Interop/FlightSql/FlightSqlParameters.cs b/csharp/test/Drivers/Interop/FlightSql/FlightSqlParameters.cs new file mode 100644 index 0000000000..f1d5e29ffa --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/FlightSqlParameters.cs @@ -0,0 +1,42 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql +{ + internal class FlightSqlParameters + { + public const string Uri = "uri"; + public const string OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"; + public const string OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."; + public const string OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch"; + public const string OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query"; + public const string OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"; + public const string OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify"; + public const string OptionAuthority = "adbc.flight.sql.client_option.authority"; + public const string Username = "username"; + public const string Password = "password"; + + // not used, but also available: + //public const string OptionMTLSCertChain = "adbc.flight.sql.client_option.mtls_cert_chain"; + //public const string OptionMTLSPrivateKey = "adbc.flight.sql.client_option.mtls_private_key"; + //public const string OptionSSLOverrideHostname = "adbc.flight.sql.client_option.tls_override_hostname"; + //public const string OptionSSLRootCerts = "adbc.flight.sql.client_option.tls_root_certs"; + //public const string OptionWithBlock = "adbc.flight.sql.client_option.with_block"; + //public const string OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size"; + //public const string OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware"; + } +} diff --git a/csharp/test/Drivers/Interop/FlightSql/FlightSqlTestConfiguration.cs b/csharp/test/Drivers/Interop/FlightSql/FlightSqlTestConfiguration.cs new file mode 100644 index 0000000000..73422fd468 --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/FlightSqlTestConfiguration.cs @@ -0,0 +1,132 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql +{ + internal class FlightSqlTestConfiguration + { + /// + /// The file path location of the driver. + /// + [JsonPropertyName("driverPath")] + public string? DriverPath { get; set; } + + /// + /// The entrypoint of the driver. + /// + [JsonPropertyName("driverEntryPoint")] + public string? DriverEntryPoint { get; set; } + + /// + /// A comma separated list of testable environments. + /// + [JsonPropertyName("testEnvironments")] + public List TestableEnvironments { get; set; } = new List(); + + /// + /// The active test environment. 
+ /// + [JsonPropertyName("environments")] + public Dictionary Environments { get; set; } = new Dictionary(); + } + + internal enum FlightSqlTestEnvironmentType + { + Denodo, + Dremio, + DuckDB, + SQLite + } + + internal class FlightSqlTestEnvironment : TestConfiguration + { + public FlightSqlTestEnvironment() + { + + } + + /// + /// The name of the environment. + /// + public string? Name { get; set; } + + [JsonConverter(typeof(JsonStringEnumConverter))] + [JsonPropertyName("type")] + public FlightSqlTestEnvironmentType EnvironmentType { get; set; } + + /// + /// The service URI. + /// + [JsonPropertyName("uri")] + public string? Uri { get; set; } + + /// + /// Additional headers to add to the gRPC call. + /// + [JsonPropertyName("headers")] + public Dictionary RPCCallHeaders { get; set; } = new Dictionary(); + + /// + /// Additional headers to add to the gRPC call. + /// + [JsonPropertyName("sqlFile")] + + public string? FlightSqlFile { get; set; } + + /// + /// The authorization header. + /// + [JsonPropertyName("authorization")] + public string? AuthorizationHeader { get; set; } + + [JsonPropertyName("timeoutFetch")] + public string? TimeoutFetch { get; set; } + + [JsonPropertyName("timeoutQuery")] + public string? TimeoutQuery { get; set; } + + [JsonPropertyName("timeoutUpdate")] + public string? TimeoutUpdate { get; set; } + + [JsonPropertyName("sslSkipVerify")] + public bool SSLSkipVerify { get; set; } + + [JsonPropertyName("authority")] + public string? Authority { get; set; } + + [JsonPropertyName("username")] + public string? Username { get; set; } + + [JsonPropertyName("password")] + public string? Password { get; set; } + + [JsonPropertyName("supportsWriteUpdate")] + public bool SupportsWriteUpdate { get; set; } = false; + + [JsonPropertyName("supportsCatalogs")] + public bool SupportsCatalogs { get; set; } = false; + + [JsonPropertyName("tableTypes")] + public List TableTypes { get; set; } = new List(); + + [JsonPropertyName("caseSensitive")] + public bool CaseSensitive { get; set; } = false; + } +} diff --git a/csharp/test/Drivers/Interop/FlightSql/FlightSqlTestingUtils.cs b/csharp/test/Drivers/Interop/FlightSql/FlightSqlTestingUtils.cs new file mode 100644 index 0000000000..f874ee56c6 --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/FlightSqlTestingUtils.cs @@ -0,0 +1,220 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Text.Json; +using Apache.Arrow.Adbc.Drivers.Interop.FlightSql; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.FlightSql +{ + class FlightSqlTestingUtils + { + internal const string FLIGHTSQL_INTEROP_TEST_CONFIG_VARIABLE = "FLIGHTSQL_INTEROP_TEST_CONFIG_FILE"; + + public static FlightSqlTestConfiguration LoadFlightSqlTestConfiguration(string? environmentVariable = null) + { + if(string.IsNullOrEmpty(environmentVariable)) + environmentVariable = FLIGHTSQL_INTEROP_TEST_CONFIG_VARIABLE; + + FlightSqlTestConfiguration? testConfiguration = null; + + if (!string.IsNullOrWhiteSpace(environmentVariable)) + { + string? environmentValue = Environment.GetEnvironmentVariable(environmentVariable); + + if (!string.IsNullOrWhiteSpace(environmentValue)) + { + if (File.Exists(environmentValue)) + { + // use a JSON file for the various settings + string json = File.ReadAllText(environmentValue); + + testConfiguration = JsonSerializer.Deserialize(json)!; + } + } + } + + if (testConfiguration == null) + throw new InvalidOperationException($"Cannot execute test configuration from environment variable `{environmentVariable}`"); + + return testConfiguration; + } + + internal static List GetTestEnvironments(FlightSqlTestConfiguration testConfiguration) + { + if (testConfiguration == null) + throw new ArgumentNullException(nameof(testConfiguration)); + + if (testConfiguration.Environments == null || testConfiguration.Environments.Count == 0) + throw new InvalidOperationException("There are no environments configured"); + + List environments = new List(); + + foreach (string environmentName in GetEnvironmentNames(testConfiguration.TestableEnvironments!)) + { + if (testConfiguration.Environments.TryGetValue(environmentName, out FlightSqlTestEnvironment? flightSqlTestEnvironment)) + { + if (flightSqlTestEnvironment != null) + { + flightSqlTestEnvironment.Name = environmentName; + environments.Add(flightSqlTestEnvironment); + } + } + } + + if (environments.Count == 0) + throw new InvalidOperationException("Could not find a configured Flight SQL environment"); + + return environments; + } + + private static List GetEnvironmentNames(List names) + { + if (names == null) + return new List(); + + return names; + } + + /// + /// Gets a the Snowflake ADBC driver with settings from the + /// . 
+ /// + /// + /// + /// + internal static AdbcDriver GetAdbcDriver( + FlightSqlTestConfiguration testConfiguration, + FlightSqlTestEnvironment environment, + out Dictionary parameters + ) + { + // see https://arrow.apache.org/adbc/main/driver/flight_sql.html + + parameters = new Dictionary{}; + + if(!string.IsNullOrEmpty(environment.Uri)) + { + parameters.Add(FlightSqlParameters.Uri, environment.Uri!); + } + + foreach(string key in environment.RPCCallHeaders.Keys) + { + parameters.Add(FlightSqlParameters.OptionRPCCallHeaderPrefix + key, environment.RPCCallHeaders[key]); + } + + if (!string.IsNullOrEmpty(environment.AuthorizationHeader)) + { + parameters.Add(FlightSqlParameters.OptionAuthorizationHeader, environment.AuthorizationHeader!); + } + else + { + if (!string.IsNullOrEmpty(environment.Username) && !string.IsNullOrEmpty(environment.Password)) + { + parameters.Add(FlightSqlParameters.Username, environment.Username!); + parameters.Add(FlightSqlParameters.Password, environment.Password!); + } + } + + if (!string.IsNullOrEmpty(environment.TimeoutQuery)) + parameters.Add(FlightSqlParameters.OptionTimeoutQuery, environment.TimeoutQuery!); + + if (!string.IsNullOrEmpty(environment.TimeoutFetch)) + parameters.Add(FlightSqlParameters.OptionTimeoutFetch, environment.TimeoutFetch!); + + if (!string.IsNullOrEmpty(environment.TimeoutUpdate)) + parameters.Add(FlightSqlParameters.OptionTimeoutUpdate, environment.TimeoutUpdate!); + + if (environment.SSLSkipVerify) + parameters.Add(FlightSqlParameters.OptionSSLSkipVerify, Convert.ToString(environment.SSLSkipVerify).ToLowerInvariant()); + + if (!string.IsNullOrEmpty(environment.Authority)) + parameters.Add(FlightSqlParameters.OptionAuthority, environment.Authority!); + + Dictionary options = new Dictionary() { }; + AdbcDriver driver = GetFlightSqlAdbcDriver(testConfiguration); + + return driver; + } + + /// + /// Gets a the Flight SQL ADBC driver with settings from the + /// . 
+ /// + /// + /// + /// + internal static AdbcDriver GetFlightSqlAdbcDriver( + FlightSqlTestConfiguration testConfiguration + ) + { + AdbcDriver driver; + + if (testConfiguration == null || string.IsNullOrEmpty(testConfiguration.DriverPath) || string.IsNullOrEmpty(testConfiguration.DriverEntryPoint)) + { + driver = FlightSqlDriverLoader.LoadDriver(); + } + else + { + driver = FlightSqlDriverLoader.LoadDriver(testConfiguration.DriverPath!, testConfiguration.DriverEntryPoint!); + } + + return driver; + } + + /// + /// Parses the queries from resources/FlightSqlData.sql + /// + /// + internal static string[] GetQueries(FlightSqlTestEnvironment environment) + { + StringBuilder content = new StringBuilder(); + + string[] sql = File.ReadAllLines(environment.FlightSqlFile!); + + Dictionary placeholderValues = new Dictionary() { + {"{ADBC_CATALOG}", environment.Metadata.Catalog }, + {"{ADBC_SCHEMA}", environment.Metadata.Schema }, + {"{ADBC_TABLE}", environment.Metadata.Table } + }; + + foreach (string line in sql) + { + if (!line.TrimStart().StartsWith("--")) + { + string modifiedLine = line; + + foreach (string key in placeholderValues.Keys) + { + if (modifiedLine.Contains(key)) + modifiedLine = modifiedLine.Replace(key, placeholderValues[key]); + } + + content.AppendLine(modifiedLine); + } + } + + string[] queries = content.ToString().Split(";".ToCharArray()).Where(x => x.Trim().Length > 0).ToArray(); + + return queries; + } + } +} diff --git a/csharp/test/Drivers/Interop/FlightSql/Resources/flightsqlconfig.json b/csharp/test/Drivers/Interop/FlightSql/Resources/flightsqlconfig.json new file mode 100644 index 0000000000..b8b82eaf12 --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/Resources/flightsqlconfig.json @@ -0,0 +1,54 @@ +{ + "driverPath": "", + "driverEntryPoint": "", + "datasourceKind": "", + "testEnvironments": [ "","" ], + "environments": { + "": { + "uri": "", + "sslSkipVerify": true, + "headers": { + "routing_tag": "", + "routing_queue": "" + }, + "supportsWriteUpdate": true, + "supportsCatalogs": true, + "type": "DuckDB", + "tableTypes": [ ], + "sqlFile": ".sql", + "metadata": { + "catalog": "", + "schema": "", + "table": "", + "expectedColumnCount": 0 + }, + "authorization": "", + "authority": "", + "query": "", + "expectedResults": 0 + }, + "": { + "uri": "", + "sslSkipVerify": false, + "headers": { + "routing_tag": "", + "routing_queue": "" + }, + "supportsWriteUpdate": false, + "supportsCatalogs": false, + "type": "Dremio", + "tableTypes": [], + "sqlFile": ".sql", + "metadata": { + "catalog": "", + "schema": "", + "table": "", + "expectedColumnCount": 0 + }, + "authorization": "", + "authority": "", + "query": "", + "expectedResults": 0 + } + } +} diff --git a/csharp/test/Drivers/Interop/FlightSql/readme.md b/csharp/test/Drivers/Interop/FlightSql/readme.md new file mode 100644 index 0000000000..f89fe4a5ff --- /dev/null +++ b/csharp/test/Drivers/Interop/FlightSql/readme.md @@ -0,0 +1,85 @@ + + +# Flight SQL +The Flight SQL tests leverage the interop nature of the C# ADBC library. These require the use of the [Flight SQL Go driver](https://github.com/apache/arrow-adbc/tree/main/go/adbc/driver/flightsql). You will need to compile the Go driver for your platform and place the driver in the correct path in order for the tests to execute correctly. + +To compile, navigate to the `go/adbc/pkg` directory of the cloned [arrow-adbc](https://github.com/apache/arrow-adbc) repository then run the `make` command. 
If you encounter compilation errors, please ensure that Go and the [GCC and C++](https://code.visualstudio.com/docs/cpp/config-mingw) tools are installed, and follow the environment setup steps in [Contributing to ADBC](https://github.com/apache/arrow-adbc/blob/main/CONTRIBUTING.md#environment-setup).
+
+## Setup
+The environment variable `FLIGHTSQL_INTEROP_TEST_CONFIG_FILE` must be set to a configuration JSON file for the tests to execute. If it is not set, the tests will show as passed with an output message that they were skipped. A template configuration file can be found in the Resources directory.
+
+## Configuration
+A growing number of data sources support Arrow Flight SQL. This library has tests that run against:
+
+- [Denodo](https://community.denodo.com/docs/html/browse/9.1/en/vdp/developer/access_through_flight_sql/connection_using_flight_sql/connection_using_flight_sql)
+- [Dremio](https://docs.dremio.com/current/sonar/developing-client-apps/arrow-flight-sql/)
+- [DuckDB](https://github.com/voltrondata/SQLFlite)
+- [SQLite](https://github.com/voltrondata/SQLFlite)
+
+It is recommended you test your data source with the Flight SQL Go driver to ensure compatibility, since each data source can implement the Flight protocol slightly differently.
+
+A sample configuration file is provided in the Resources directory. The configuration file is a JSON file that contains the following fields:
+
+- **uri**: The endpoint for the service
+- **username**: User name to use for authentication
+- **password**: Password to use for authentication
+- **sslSkipVerify**: Whether to skip TLS certificate verification (sets the `adbc.flight.sql.client_option.tls_skip_verify` option)
+- **headers**: Key/value pairs of additional headers to include with the request.
+- **supportsWriteUpdate**: Indicates whether the data source supports creating new tables
+- **supportsCatalogs**: Indicates whether the data source supports catalog names
+- **type**: Specifies the type of data source used for running data from FlightSqlData. The supported types are:
+  - Dremio
+  - Denodo
+  - DuckDB
+  - SQLite
+- **tableTypes**: The table types to include in the GetObjects call
+- **sqlFile**: A path to a SQL file to run queries to test CRUD operations
+- **metadata**: Used for the GetObjects calls
+  - **catalog**: The catalog name to use for the GetObjects call
+  - **schema**: The schema name to use for the GetObjects call
+  - **table**: The table name to use for the GetObjects call
+  - **expectedColumnCount**: The number of columns that should be returned
+- **authorization**: Used to set the `adbc.flight.sql.authorization_header` property
+- **authority**: Used to set the `adbc.flight.sql.client_option.authority` property
+- **query**: Select query run against the data source
+- **expectedResults**: Number of results expected from the query
+
+The configuration file supports targeting multiple data sources
+simultaneously. To use multiple data sources, you can configure them like:
+
+```
+  "testEnvironments": [
+    "Dremio_Remote",
+    "DuckDb_Local",
+    "SQLite_Local"
+  ],
+  "environments": {
+    "SQLite_Local":
+    {
+       ...
+    },
+    "DuckDb_Local":
+    {
+       ...
+    },
+    "Dremio_Remote": {
+       ...
+ } +``` diff --git a/csharp/test/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj b/csharp/test/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj index 488603d4b3..f660e7a645 100644 --- a/csharp/test/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj +++ b/csharp/test/Drivers/Interop/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj @@ -16,13 +16,13 @@ - - - + + + all runtime; build; native; contentfiles; analyzers; buildtransitive - + diff --git a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs index 04cf0d1eb8..b5f0457401 100644 --- a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs +++ b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs @@ -142,6 +142,37 @@ public void CanClientExecuteQueryWithNoResults() } } + // + /// Validates if the client can connect to a live server and execute a parameterized query. + /// + [SkippableFact, Order(4)] + public void CanClientExecuteParameterizedQuery() + { + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + testConfiguration.Query = "SELECT ? as A, ? as B, ? as C, * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?"; + testConfiguration.ExpectedResultsCount = 1; + + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) + { + Tests.ClientTests.CanClientExecuteQuery(adbcConnection, testConfiguration, command => + { + DbParameter CreateParameter(DbType dbType, object value) + { + DbParameter result = command.CreateParameter(); + result.DbType = dbType; + result.Value = value; + return result; + } + + // TODO: Add tests for decimal and time once supported by the driver or gosnowflake + command.Parameters.Add(CreateParameter(DbType.Int32, 2)); + command.Parameters.Add(CreateParameter(DbType.String, "text")); + command.Parameters.Add(CreateParameter(DbType.Double, 2.5)); + command.Parameters.Add(CreateParameter(DbType.Int32, 2)); + }); + } + } + // /// Validates if the client can connect to a live server /// and parse the results. diff --git a/csharp/test/Drivers/Interop/Snowflake/ConstraintTests.cs b/csharp/test/Drivers/Interop/Snowflake/ConstraintTests.cs index 19a71cdc4d..c57eb636a6 100644 --- a/csharp/test/Drivers/Interop/Snowflake/ConstraintTests.cs +++ b/csharp/test/Drivers/Interop/Snowflake/ConstraintTests.cs @@ -62,7 +62,7 @@ public void CanGetObjectsTableConstraintsWithColumnNameFilter(string constraintT using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); List? tables = catalogs .Where(c => string.Equals(c.Name, databaseName)) @@ -110,7 +110,7 @@ public void CanGetObjectsTableConstraints(string constraintType, string constrai using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); List? 
tables = catalogs .Where(c => string.Equals(c.Name, databaseName)) diff --git a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs index 4e9fbb4f41..56f1938573 100644 --- a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs +++ b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs @@ -87,6 +87,80 @@ public DriverTests() _connection = _database.Connect(options); } + /// + /// Validates if the DEFAULT_ROLE works correctly for ADBC. + /// + [SkippableFact, Order(1)] + public void ValidateUserRole() + { + Skip.If(_testConfiguration.RoleInfo == null); + + // first test with the DEFAULT_ROLE value + Assert.True(CurrentRoleIsExpectedRole(_connection, _testConfiguration.RoleInfo.DefaultRole)); ; + + // now connect with the new role and ensure we get the new role successfully + Dictionary parameters = new Dictionary(); + Dictionary options = new Dictionary(); + + using AdbcDriver localSnowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(_testConfiguration, out parameters); + parameters.Add(SnowflakeParameters.ROLE, _testConfiguration.RoleInfo.NewRole); + + using AdbcDatabase localDatabase = localSnowflakeDriver.Open(parameters); + using AdbcConnection localConnection = localDatabase.Connect(options); + Assert.True(CurrentRoleIsExpectedRole(localConnection, _testConfiguration.RoleInfo.NewRole)); + } + + private bool CurrentRoleIsExpectedRole(AdbcConnection cn, string expectedRole) + { + using AdbcStatement statement = cn.CreateStatement(); + statement.SqlQuery = "SELECT CURRENT_ROLE() as CURRENT_ROLE;"; + + QueryResult queryResult = statement.ExecuteQuery(); + using RecordBatch? recordBatch = queryResult.Stream?.ReadNextRecordBatchAsync().Result; + Assert.True(recordBatch != null); + + StringArray stringArray = (StringArray)recordBatch.Column("CURRENT_ROLE"); + Assert.True(stringArray.Length > 0); + + return expectedRole == stringArray.GetString(0); + } + + [SkippableFact, Order(1)] + public void CanSetDatabase() + { + Skip.If(string.IsNullOrEmpty(_testConfiguration.Metadata.Catalog)); + + // connect without the parameter and ensure we get the DATABASE successfully + Dictionary parameters = new Dictionary(); + Dictionary options = new Dictionary(); + + using AdbcDriver localSnowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(_testConfiguration, out parameters); + parameters.Remove(SnowflakeParameters.DATABASE); + using AdbcDatabase localDatabase = localSnowflakeDriver.Open(parameters); + using AdbcConnection localConnection = localDatabase.Connect(options); + + localConnection.SetOption(AdbcOptions.Connection.CurrentCatalog, _testConfiguration.Metadata.Catalog); + + Assert.True(CurrentDatabaseIsExpectedCatalog(localConnection, _testConfiguration.Metadata.Catalog)); + + localConnection.GetObjects(AdbcConnection.GetObjectsDepth.All, _testConfiguration.Metadata.Catalog, _testConfiguration.Metadata.Schema, _testConfiguration.Metadata.Table, _tableTypes, null); + } + + private bool CurrentDatabaseIsExpectedCatalog(AdbcConnection cn, string expectedCatalog) + { + using AdbcStatement statement = cn.CreateStatement(); + statement.SqlQuery = "SELECT CURRENT_DATABASE() as CURRENT_DATABASE;"; // GetOption doesn't exist in 1.0, only 1.1 + + QueryResult queryResult = statement.ExecuteQuery(); + using RecordBatch? 
recordBatch = queryResult.Stream?.ReadNextRecordBatchAsync().Result; + Assert.True(recordBatch != null); + + StringArray stringArray = (StringArray)recordBatch.Column("CURRENT_DATABASE"); + Assert.True(stringArray.Length > 0); + + return expectedCatalog == stringArray.GetString(0); + } + /// /// Validates if the driver can connect to a live server and /// parse the results. @@ -155,7 +229,7 @@ public void CanGetObjectsCatalogs(string catalogPattern) using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, null); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, null); AdbcCatalog? catalog = catalogs.Where((catalog) => string.Equals(catalog.Name, databaseName)).FirstOrDefault(); Assert.True(catalog != null, "catalog should not be null"); @@ -182,7 +256,7 @@ public void CanGetObjectsDbSchemas(string dbSchemaPattern) using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); List? dbSchemas = catalogs .Where(c => string.Equals(c.Name, databaseName)) @@ -215,7 +289,7 @@ public void CanGetObjectsTables(string tableNamePattern) using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); List? tables = catalogs .Where(c => string.Equals(c.Name, databaseName)) @@ -252,7 +326,7 @@ public void CanGetObjectsAll() using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); AdbcTable? table = catalogs .Where(c => string.Equals(c.Name, databaseName)) .Select(c => c.DbSchemas) @@ -318,7 +392,7 @@ public void CanGetObjectsTablesWithSpecialCharacter(string databaseName, string using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, schemaName); List? tables = catalogs .Where(c => string.Equals(c.Name, databaseName)) @@ -399,6 +473,24 @@ public void CanExecuteQuery() Tests.DriverTests.CanExecuteQuery(queryResult, _testConfiguration.ExpectedResultsCount); } + /// + /// Validates if the driver can connect to a live server and execute a parameterized query. 
+ /// + [SkippableFact, Order(6)] + public void CanExecuteParameterizedQuery() + { + using AdbcStatement statement = _connection.CreateStatement(); + statement.SqlQuery = "SELECT * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?"; + + Schema parameterSchema = new Schema(new[] { new Field("column1", Int32Type.Default, false) }, null); + RecordBatch parameters = new RecordBatch(parameterSchema, new[] { new Int32Array.Builder().Append(2).Build() }, 1); + statement.Bind(parameters, parameterSchema); + + QueryResult queryResult = statement.ExecuteQuery(); + + Tests.DriverTests.CanExecuteQuery(queryResult, 1); + } + [SkippableFact, Order(7)] public void CanIngestData() { diff --git a/csharp/test/Drivers/Interop/Snowflake/Resources/snowflakeconfig.json b/csharp/test/Drivers/Interop/Snowflake/Resources/snowflakeconfig.json index dec0c492e3..70cd008ef7 100644 --- a/csharp/test/Drivers/Interop/Snowflake/Resources/snowflakeconfig.json +++ b/csharp/test/Drivers/Interop/Snowflake/Resources/snowflakeconfig.json @@ -26,6 +26,10 @@ "password": "" } }, + "roleInfo": { + "defaultRole": "", + "newRole": "" + }, "query": "", "expectedResults": 0 } diff --git a/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestConfiguration.cs b/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestConfiguration.cs index 7c8a054925..86968aa798 100644 --- a/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestConfiguration.cs +++ b/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestConfiguration.cs @@ -84,6 +84,11 @@ internal class SnowflakeTestConfiguration : TestConfiguration [JsonPropertyName("authentication")] public SnowflakeAuthentication Authentication { get; set; } = new SnowflakeAuthentication(); + /// + /// The snowflake Authentication + /// + [JsonPropertyName("roleInfo")] + public RoleInfo? 
RoleInfo { get; set; } } public class SnowflakeAuthentication @@ -134,4 +139,13 @@ public class DefaultAuthentication [JsonPropertyName("password")] public string Password { get; set; } = string.Empty; } + + public class RoleInfo + { + [JsonPropertyName("defaultRole")] + public string DefaultRole { get; set; } = string.Empty; + + [JsonPropertyName("newRole")] + public string NewRole { get; set; } = string.Empty; + } } diff --git a/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestingUtils.cs b/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestingUtils.cs index 06dd3d9676..f7943a5e43 100644 --- a/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestingUtils.cs +++ b/csharp/test/Drivers/Interop/Snowflake/SnowflakeTestingUtils.cs @@ -33,6 +33,7 @@ internal class SnowflakeParameters public const string ACCOUNT = "adbc.snowflake.sql.account"; public const string USERNAME = "username"; public const string PASSWORD = "password"; + public const string ROLE = "adbc.snowflake.sql.role"; public const string WAREHOUSE = "adbc.snowflake.sql.warehouse"; public const string AUTH_TYPE = "adbc.snowflake.sql.auth_type"; public const string AUTH_TOKEN = "adbc.snowflake.sql.client_option.auth_token"; diff --git a/csharp/test/SmokeTests/Apache.Arrow.Adbc.SmokeTests/Apache.Arrow.Adbc.SmokeTests.csproj b/csharp/test/SmokeTests/Apache.Arrow.Adbc.SmokeTests/Apache.Arrow.Adbc.SmokeTests.csproj index f5c5ef8baa..20eb51ce40 100644 --- a/csharp/test/SmokeTests/Apache.Arrow.Adbc.SmokeTests/Apache.Arrow.Adbc.SmokeTests.csproj +++ b/csharp/test/SmokeTests/Apache.Arrow.Adbc.SmokeTests/Apache.Arrow.Adbc.SmokeTests.csproj @@ -30,7 +30,7 @@ - + diff --git a/csharp/test/SmokeTests/Interop/FlightSql/Apache.Arrow.Adbc.SmokeTests.Drivers.Interop.FlightSql.csproj b/csharp/test/SmokeTests/Interop/FlightSql/Apache.Arrow.Adbc.SmokeTests.Drivers.Interop.FlightSql.csproj new file mode 100644 index 0000000000..af8e633436 --- /dev/null +++ b/csharp/test/SmokeTests/Interop/FlightSql/Apache.Arrow.Adbc.SmokeTests.Drivers.Interop.FlightSql.csproj @@ -0,0 +1,33 @@ + + + + net472;net6.0 + + + + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + PreserveNewest + + + + + + diff --git a/csharp/test/SmokeTests/build.props b/csharp/test/SmokeTests/build.props index 2b3620f03f..8b7221969a 100644 --- a/csharp/test/SmokeTests/build.props +++ b/csharp/test/SmokeTests/build.props @@ -4,5 +4,6 @@ $(Version) $(Version) $(Version) + $(Version) diff --git a/dev/bench/README.md b/dev/bench/README.md index 1c332885dc..63db5c8d42 100644 --- a/dev/bench/README.md +++ b/dev/bench/README.md @@ -28,3 +28,15 @@ functions for testing the ADBC Snowflake driver, the [snowflake-python-connector If `matplotlib` is installed, it will also draw the timing and memory usage up as charts which can be saved. + +# ODBC benchmark + +The file odbc/main.cc contains code to utilize an ODBC driver and the +BindCol interface in order to perform a simple query and retrieve data. +This was used for benchmarking against Snowflake to compare with the ADBC +Snowflake driver. + +It can be built by simply using `cmake` as long as you have unixODBC or +another ODBC library that can be found by `cmake` for building. After +building the mainprog, it can be run with a single argument being the ODBC +DSN to use such as "DSN=snowflake;UID=;PWD=;". 
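To make the README's build-and-run steps concrete, here is a minimal sketch assuming an out-of-source CMake build from the repository root; the `odbcbench` target name comes from `dev/bench/odbc/CMakeLists.txt` below, and the DSN, user, and password values are placeholders:

```shell
# Configure and build the benchmark (requires unixODBC or another ODBC library CMake can find)
cmake -S dev/bench/odbc -B build/odbcbench
cmake --build build/odbcbench

# Run against a configured Snowflake ODBC DSN (placeholder credentials)
./build/odbcbench/odbcbench "DSN=snowflake;UID=<user>;PWD=<password>;"
```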
diff --git a/dev/bench/odbc/CMakeLists.txt b/dev/bench/odbc/CMakeLists.txt new file mode 100644 index 0000000000..eb3e533b8f --- /dev/null +++ b/dev/bench/odbc/CMakeLists.txt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.21) +project(odbcbench LANGUAGES C CXX) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(ODBC) +add_executable(odbcbench main.cc) +target_link_libraries(odbcbench ODBC::ODBC) +set_target_properties(odbcbench PROPERTIES BUILD_RPATH "\$ORIGIN" INSTALL_RPATH + "\$ORIGIN") diff --git a/dev/bench/odbc/main.cc b/dev/bench/odbc/main.cc new file mode 100644 index 0000000000..1f574629fe --- /dev/null +++ b/dev/bench/odbc/main.cc @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define ERRMSG_LEN 200 + +SQLINTEGER checkError(SQLRETURN rc, SQLSMALLINT handleType, SQLHANDLE handle, + SQLWCHAR* errmsg) { + SQLRETURN retcode = SQL_SUCCESS; + + SQLSMALLINT errNum = 1; + SQLCHAR sqlState[6]; // always exactly 5 characters + NUL + SQLINTEGER nativeError; + SQLCHAR errMsg[ERRMSG_LEN]; + SQLSMALLINT textLengthPtr; + + if ((rc != SQL_SUCCESS) && (rc != SQL_SUCCESS_WITH_INFO) && (rc != SQL_NO_DATA)) { + SQLLEN numRecs = 0; + SQLGetDiagField(SQL_HANDLE_STMT, handle, 0, SQL_DIAG_NUMBER, &numRecs, 0, 0); + while (retcode != SQL_NO_DATA) { + retcode = SQLGetDiagRecA(handleType, handle, errNum, sqlState, &nativeError, errMsg, + ERRMSG_LEN, &textLengthPtr); + + if (retcode == SQL_INVALID_HANDLE) { + std::cerr << "checkError function was called with an invalid handle!!" 
+ << std::endl; + return 1; + } + + if ((retcode != SQL_SUCCESS) && (retcode != SQL_SUCCESS_WITH_INFO)) { + wprintf(L"ERROR: %d: %ls : %ls\n", nativeError, sqlState, errMsg); + } + + errNum++; + } + + wprintf(L"%ls\n", errmsg); + return 1; + } + + return 0; +} + +#define CHECK_OK(EXPR, ERROR) \ + do { \ + auto ret = (EXPR); \ + if (checkError(ret, SQL_HANDLE_DBC, dbc, (SQLWCHAR*)L##ERROR)) { \ + exit(1); \ + } \ + } while (0) + +int main(int argc, char** argv) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + SQLHSTMT stmt1; + + if (argc != 2) { + std::cerr << "Expected exactly 1 argument: the DSN for connecting, got " << argc - 1 + << "arguments. exiting..."; + return 1; + } + + std::string dsn(argv[1]); + + SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); + // we want ODBC3 support, set env handle + SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(SQL_OV_ODBC3_80), 0); + // allocate connection handle + SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc); + + // connect to the DSN using dbc handle + CHECK_OK(SQLDriverConnectA(dbc, nullptr, + reinterpret_cast(const_cast(dsn.c_str())), + SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), + "Error -- Driver Connect failed"); + + std::ofstream timing_output("odbc_perf_record"); + + for (size_t iter = 0; iter < 100; iter++) { + CHECK_OK(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), + "Error -- Statement handle alloc failed"); + + const auto start{std::chrono::steady_clock::now()}; + CHECK_OK(SQLExecDirect( + stmt, + (SQLCHAR*)"SELECT * FROM SNOWFLAKE_SAMPLE_DATA.TPCH_SF1000.LINEITEM " + "LIMIT 100000000", + SQL_NTS), + "Error -- Statement execution failed"); + + const SQLULEN bulkSize = 10000; + CHECK_OK(SQLSetStmtAttr(stmt, SQL_ATTR_ROW_ARRAY_SIZE, + reinterpret_cast(bulkSize), 0), + "Error -- SetStmtAttr failed"); + + // bind columns to buffers + SQLINTEGER val_orderkey[bulkSize]; + SQLINTEGER val_partkey[bulkSize]; + SQLINTEGER val_suppkey[bulkSize]; + SQLINTEGER val_linenumber[bulkSize]; + SQLDOUBLE val_quantity[bulkSize]; + SQLDOUBLE val_extendedprice[bulkSize]; + SQLDOUBLE val_discount[bulkSize]; + SQLDOUBLE val_tax[bulkSize]; + SQLCHAR val_retflag[bulkSize][2]; + SQLCHAR val_linestatus[bulkSize][2]; + SQL_DATE_STRUCT val_shipdate[bulkSize]; + SQL_DATE_STRUCT val_commitdate[bulkSize]; + SQL_DATE_STRUCT val_receiptdate[bulkSize]; + SQLCHAR val_shipinstruct[bulkSize][26]; + SQLCHAR val_shipmode[bulkSize][11]; + SQLCHAR val_comment[bulkSize][45]; + + CHECK_OK(SQLBindCol(stmt, 1, SQL_C_LONG, reinterpret_cast(val_orderkey), + sizeof(val_orderkey), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 2, SQL_C_LONG, reinterpret_cast(val_partkey), + sizeof(val_partkey), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 3, SQL_C_LONG, reinterpret_cast(val_suppkey), + sizeof(val_suppkey), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 4, SQL_C_LONG, reinterpret_cast(val_linenumber), + sizeof(val_linenumber), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 5, SQL_C_DOUBLE, reinterpret_cast(val_quantity), + sizeof(val_quantity), nullptr), + "BindCol failed"); + CHECK_OK( + SQLBindCol(stmt, 6, SQL_C_DOUBLE, reinterpret_cast(val_extendedprice), + sizeof(val_extendedprice), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 7, SQL_C_DOUBLE, reinterpret_cast(val_discount), + sizeof(val_discount), nullptr), + "BindCol failed"); + CHECK_OK( + SQLBindCol(stmt, 8, SQL_C_DOUBLE, (SQLPOINTER)val_tax, sizeof(val_tax), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 9, SQL_C_CHAR, 
(SQLPOINTER)val_retflag, + sizeof(val_retflag[0]), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 10, SQL_C_CHAR, (SQLPOINTER)val_linestatus, + sizeof(val_linestatus[0]), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 11, SQL_C_DATE, (SQLPOINTER)val_shipdate, + sizeof(val_shipdate), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 12, SQL_C_DATE, (SQLPOINTER)val_commitdate, + sizeof(val_commitdate), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 13, SQL_C_DATE, (SQLPOINTER)val_receiptdate, + sizeof(val_receiptdate), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 14, SQL_C_CHAR, (SQLPOINTER)val_shipinstruct, + sizeof(val_shipinstruct[0]), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 15, SQL_C_CHAR, (SQLPOINTER)val_shipmode, + sizeof(val_shipmode[0]), nullptr), + "BindCol failed"); + CHECK_OK(SQLBindCol(stmt, 16, SQL_C_CHAR, (SQLPOINTER)val_comment, + sizeof(val_comment[0]), nullptr), + "BindCol failed"); + + SQLRETURN ret; + while (true) { + ret = SQLFetch(stmt); + if (checkError(ret, SQL_HANDLE_DBC, dbc, (SQLWCHAR*)L"fetch failed")) { + exit(1); + } + if (ret == SQL_NO_DATA) break; + } + + const auto end{std::chrono::steady_clock::now()}; + const std::chrono::duration elapsed{end - start}; + timing_output << elapsed.count() << std::endl; + + SQLFreeStmt(stmt, SQL_CLOSE); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + if (iter % 10 == 0) { + std::cout << "Run " << iter << std::endl; + std::cout << "\tRuntime: " << elapsed.count() << " s" << std::endl; + } + } + + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/dev/release/01-prepare.sh b/dev/release/01-prepare.sh index 279e764b68..8767e91d8f 100755 --- a/dev/release/01-prepare.sh +++ b/dev/release/01-prepare.sh @@ -55,8 +55,7 @@ main() { # it by hand. 
( echo ; - # Strip trailing blank line - printf '%s\n' "$(cz ch --dry-run --unreleased-version "ADBC Libraries ${RELEASE}" --start-rev apache-arrow-adbc-${PREVIOUS_RELEASE})" + changelog ) >> ${SOURCE_DIR}/../../CHANGELOG.md git add ${SOURCE_DIR}/../../CHANGELOG.md git commit -m "chore: update CHANGELOG.md for ${RELEASE}" diff --git a/dev/release/02-sign.sh b/dev/release/02-sign.sh index 1bc0673949..728407dfd4 100755 --- a/dev/release/02-sign.sh +++ b/dev/release/02-sign.sh @@ -61,9 +61,7 @@ main() { --skip-existing header "Adding release notes" - # XXX: commitizen likes to include the entire history even if we - # give it a tag, so we have to give it both tags explicitly - local -r release_notes=$(cz ch --dry-run --unreleased-version "ADBC Libraries ${RELEASE}" --start-rev apache-arrow-adbc-${PREVIOUS_RELEASE}) + local -r release_notes=$(changelog) echo "${release_notes}" gh release edit \ "${tag}" \ diff --git a/dev/release/post-05-linux.sh b/dev/release/post-05-linux.sh index 2cc6a3a4f3..5a4568b428 100755 --- a/dev/release/post-05-linux.sh +++ b/dev/release/post-05-linux.sh @@ -41,7 +41,7 @@ main() { export DEPLOY_ALMALINUX=${DEPLOY_ALMALINUX:-1} export DEPLOY_DEBIAN=${DEPLOY_DEBIAN:-1} export DEPLOY_UBUNTU=${DEPLOY_UBUNTU:-1} - "${arrow_dir}/dev/release/post-02-binary.sh" "${RELEASE}" "${rc_number}" + "${arrow_dir}/dev/release/post-03-binary.sh" "${RELEASE}" "${rc_number}" } main "$@" diff --git a/dev/release/post-08-rust.sh b/dev/release/post-08-rust.sh new file mode 100755 index 0000000000..e31427d5b7 --- /dev/null +++ b/dev/release/post-08-rust.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash +# -*- indent-tabs-mode: nil; sh-indentation: 2; sh-basic-offset: 2 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +set -e +set -u +set -o pipefail + +SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +source "${SOURCE_DIR}/utils-common.sh" +source "${SOURCE_DIR}/utils-prepare.sh" + +main() { + if [ "$#" -ne 0 ]; then + echo "Usage: $0" + exit + fi + + local -r tag="apache-arrow-adbc-${RELEASE}" + # Ensure we are being run from the tag + if [[ $(git describe --exact-match --tags) != "${tag}" ]]; then + echo "This script must be run from the tag ${tag}" + exit 1 + fi + + pushd "${SOURCE_TOP_DIR}/rust" + cargo publish --all-features -p adbc_core + popd + + echo "Success! The released Cargo crate is available here:" + echo " https://crates.io/crates/adbc_core" +} + +main "$@" diff --git a/dev/release/post-09-announce.sh b/dev/release/post-09-announce.sh new file mode 100755 index 0000000000..c77bf7bc96 --- /dev/null +++ b/dev/release/post-09-announce.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +set -ue + +SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +source "${SOURCE_DIR}/utils-common.sh" +source "${SOURCE_DIR}/utils-prepare.sh" + +main() { + + set_resolved_issues "${RELEASE}" + + cat < -The Apache Arrow team is pleased to announce the ${RELEASE} release of -the Apache Arrow ADBC libraries. This covers includes [**${RESOLVED_ISSUES} +The Apache Arrow team is pleased to announce the version ${RELEASE} release of +the Apache Arrow ADBC libraries. This release includes [**${RESOLVED_ISSUES} resolved issues**][1] from [**${contributors} distinct contributors**][2]. -This is a release of the **libraries**, which are at version -${RELEASE}. The **API specification** is versioned separately and is -at version ${spec_version}. +This is a release of the **libraries**, which are at version ${RELEASE}. The +[**API specification**][specification] is versioned separately and is at +version ${spec_version}. The subcomponents are versioned independently: @@ -112,6 +112,7 @@ or the [Arrow mailing lists][5]. [3]: https://github.com/apache/arrow-adbc/blob/apache-arrow-adbc-${RELEASE}/CHANGELOG.md [4]: https://github.com/apache/arrow-adbc/issues [5]: {% link community.md %} +[specification]: https://arrow.apache.org/adbc/${RELEASE}/format/specification.html EOF } diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 459a272b7f..7ffcfc13e2 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -1,6 +1,10 @@ *.json *.Rproj *.Rd +c/subprojects/fmt.wrap +c/subprojects/gtest.wrap +c/subprojects/nanoarrow.wrap +c/subprojects/sqlite3.wrap c/vendor/backward/* c/vendor/fmt/* c/vendor/sqlite3/sqlite3.c @@ -17,7 +21,7 @@ dev/release/rat_exclude_files.txt docs/source/format/*.drawio docs/source/format/*.svg filtered_rat.txt -go/adbc/drivermgr/adbc.h +go/adbc/drivermgr/arrow-adbc/adbc.h go/adbc/drivermgr/adbc_driver_manager.cc go/adbc/drivermgr/adbc_driver_manager.h go/adbc/status_string.go diff --git a/dev/release/utils-common.sh b/dev/release/utils-common.sh index 2cec70b667..99c9b66331 100644 --- a/dev/release/utils-common.sh +++ b/dev/release/utils-common.sh @@ -28,6 +28,7 @@ if [[ ! 
-f "${SOURCE_DIR}/.env" ]]; then fi source "${SOURCE_DIR}/.env" +source "${SOURCE_DIR}/versions.env" header() { echo "============================================================" @@ -35,6 +36,24 @@ header() { echo "============================================================" } +changelog() { + # Strip trailing blank line + local -r changelog=$(printf '%s\n' "$(cz ch --dry-run --unreleased-version "ADBC Libraries ${RELEASE}" --start-rev apache-arrow-adbc-${PREVIOUS_RELEASE})") + # Split off header + local -r header=$(echo "${changelog}" | head -n 1) + local -r trailer=$(echo "${changelog}" | tail -n+2) + echo "${header}" + echo + echo "### Versions" + echo + echo "- C/C++/GLib/Go/Python/Ruby: ${VERSION_NATIVE}" + echo "- C#: ${VERSION_CSHARP}" + echo "- Java: ${VERSION_JAVA}" + echo "- R: ${VERSION_R}" + echo "- Rust: ${VERSION_RUST}" + echo "${trailer}" +} + header "Config" echo "Repository: ${REPOSITORY}" diff --git a/dev/release/utils-prepare.sh b/dev/release/utils-prepare.sh index 52639cb045..43292d3771 100644 --- a/dev/release/utils-prepare.sh +++ b/dev/release/utils-prepare.sh @@ -28,6 +28,7 @@ update_versions() { case ${type} in release) local c_version="${VERSION_NATIVE}" + local csharp_suffix="" local docs_version="${RELEASE}" local glib_version="${VERSION_NATIVE}" local java_version="${VERSION_JAVA}" @@ -36,6 +37,7 @@ update_versions() { ;; snapshot) local c_version="${VERSION_NATIVE}-SNAPSHOT" + local csharp_suffix="SNAPSHOT" local docs_version="${RELEASE} (dev)" local glib_version="${VERSION_NATIVE}-SNAPSHOT" local java_version="${VERSION_JAVA}-SNAPSHOT" @@ -64,6 +66,11 @@ update_versions() { sed -i.bak -E "s/set\(ADBC_VERSION \".+\"\)/set(ADBC_VERSION \"${c_version}\")/g" cmake_modules/AdbcVersion.cmake rm cmake_modules/AdbcVersion.cmake.bak git add cmake_modules/AdbcVersion.cmake + + # Avoid changing meson_version + sed -i.bak -E "s/ version: '.+',/ version: '${c_version}',/g" meson.build + rm meson.build.bak + git add meson.build popd pushd "${ADBC_DIR}/ci/conda/" @@ -72,7 +79,13 @@ update_versions() { git add meta.yaml popd - sed -i.bak -E "s|.+|${csharp_version}|" "${ADBC_DIR}/csharp/Directory.Build.props" + sed -i.bak \ + -E "s|.+|${csharp_version}|" \ + "${ADBC_DIR}/csharp/Directory.Build.props" + rm "${ADBC_DIR}/csharp/Directory.Build.props.bak" + sed -i.bak \ + -E "s|.+|${csharp_suffix}|" \ + "${ADBC_DIR}/csharp/Directory.Build.props" rm "${ADBC_DIR}/csharp/Directory.Build.props.bak" git add "${ADBC_DIR}/csharp/Directory.Build.props" diff --git a/dev/release/verify-release-candidate.ps1 b/dev/release/verify-release-candidate.ps1 index 2425f463cc..13f289f93c 100755 --- a/dev/release/verify-release-candidate.ps1 +++ b/dev/release/verify-release-candidate.ps1 @@ -140,6 +140,7 @@ New-Item -ItemType Directory -Force -Path $CppBuildDir | Out-Null # XXX(apache/arrow-adbc#634): not working on Windows due to it picking # up MSVC as the C compiler, which then blows up when /Werror gets # passed in by some package +$env:BUILD_DRIVER_BIGQUERY = "0" $env:BUILD_DRIVER_FLIGHTSQL = "0" $env:BUILD_DRIVER_SNOWFLAKE = "0" diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 825a450d44..a5b277560d 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -21,7 +21,7 @@ # Requirements # - Ruby >= 2.3 # - Maven >= 3.3.9 -# - JDK >=7 +# - JDK >= 11 # - gcc >= 4.8 # - Go >= 1.21 # - Docker @@ -239,6 +239,59 @@ setup_tempdir() { echo "Working in sandbox ${ARROW_TMPDIR}" } +install_dotnet() { + # 
Install C# if doesn't already exist + if [ "${DOTNET_ALREADY_INSTALLED:-0}" -gt 0 ]; then + show_info ".NET already installed $(which csharp) (.NET $(dotnet --version))" + return 0 + fi + + show_info "Ensuring that .NET is installed..." + + if dotnet --version | grep 8\.0 > /dev/null 2>&1; then + local csharp_bin=$(dirname $(which dotnet)) + show_info "Found C# at $(which csharp) (.NET $(dotnet --version))" + else + if which dotnet > /dev/null 2>&1; then + show_info "dotnet found but it is the wrong version and will be ignored." + fi + local csharp_bin=${ARROW_TMPDIR}/csharp/bin + local dotnet_version=8.0.204 + local dotnet_platform= + case "$(uname)" in + Linux) + dotnet_platform=linux + ;; + Darwin) + dotnet_platform=macos + ;; + esac + local dotnet_download_thank_you_url=https://dotnet.microsoft.com/download/thank-you/dotnet-sdk-${dotnet_version}-${dotnet_platform}-x64-binaries + local dotnet_download_url=$( \ + curl -sL ${dotnet_download_thank_you_url} | \ + grep 'directLink' | \ + grep -E -o 'https://download[^"]+' | \ + sed -n 2p) + mkdir -p ${csharp_bin} + curl -sL ${dotnet_download_url} | \ + tar xzf - -C ${csharp_bin} + PATH=${csharp_bin}:${PATH} + show_info "Installed C# at $(which csharp) (.NET $(dotnet --version))" + fi + + # Ensure to have sourcelink installed + if ! dotnet tool list | grep sourcelink > /dev/null 2>&1; then + dotnet new tool-manifest + dotnet tool install --local sourcelink + PATH=${csharp_bin}:${PATH} + if ! dotnet tool run sourcelink --help > /dev/null 2>&1; then + export DOTNET_ROOT=${csharp_bin} + fi + fi + + DOTNET_ALREADY_INSTALLED=1 +} + install_go() { # Install go if [ "${GO_ALREADY_INSTALLED:-0}" -gt 0 ]; then @@ -287,10 +340,49 @@ install_go() { GO_ALREADY_INSTALLED=1 } +install_rust() { + if [ "${RUST_ALREADY_INSTALLED:-0}" -gt 0 ]; then + show_info "Rust already installed at $(command -v cargo)" + show_info "$(cargo --version)" + return 0 + fi + + if [[ -f ${ARROW_TMPDIR}/cargo/env ]]; then + source ${ARROW_TMPDIR}/cargo/env + rustup default stable + show_info "$(cargo version) installed at $(command -v cargo)" + RUST_ALREADY_INSTALLED=1 + return 0 + fi + + if command -v cargo > /dev/null; then + show_info "Found $(cargo version) at $(command -v cargo)" + RUST_ALREADY_INSTALLED=1 + return 0 + fi + + show_info "Installing Rust..." + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs |\ + env \ + RUSTUP_HOME=${ARROW_TMPDIR}/rustup \ + CARGO_HOME=${ARROW_TMPDIR}/cargo \ + sh -s -- \ + --default-toolchain stable \ + --no-modify-path \ + -y + + source ${ARROW_TMPDIR}/cargo/env + rustup default stable + + show_info "$(cargo version) installed at $(command -v cargo)" + + RUST_ALREADY_INSTALLED=1 +} + install_conda() { # Setup short-lived miniconda for Python and integration tests show_info "Ensuring that Conda is installed..." - local prefix=$ARROW_TMPDIR/mambaforge + local prefix=$ARROW_TMPDIR/miniforge # Setup miniconda only if the directory doesn't exist yet if [ "${CONDA_ALREADY_INSTALLED:-0}" -eq 0 ]; then @@ -298,7 +390,7 @@ install_conda() { show_info "Installing miniconda at ${prefix}..." 
local arch=$(uname -m) local platform=$(uname) - local url="https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-${platform}-${arch}.sh" + local url="https://github.com/conda-forge/miniforge/releases/latest/download/miniforge3-${platform}-${arch}.sh" curl -sL -o miniconda.sh $url bash miniconda.sh -b -p $prefix rm -f miniconda.sh @@ -318,9 +410,7 @@ install_conda() { maybe_setup_conda() { # Optionally setup conda environment with the passed dependencies local env="conda-${CONDA_ENV:-source}" - # XXX(https://github.com/apache/arrow-adbc/issues/1247): no duckdb for - # python 3.12 on conda-forge right now - local pyver=${PYTHON_VERSION:-3.11} + local pyver=${PYTHON_VERSION:-3} if [ "${USE_CONDA}" -gt 0 ]; then show_info "Configuring Conda environment..." @@ -348,6 +438,13 @@ maybe_setup_conda() { fi } +maybe_setup_dotnet() { + show_info "Ensuring that .NET is installed..." + if [ "${USE_CONDA}" -eq 0 ]; then + install_dotnet + fi +} + maybe_setup_virtualenv() { # Optionally setup pip virtualenv with the passed dependencies local env="venv-${VENV_ENV:-source}" @@ -404,22 +501,30 @@ maybe_setup_go() { fi } +maybe_setup_rust() { + show_info "Ensuring that Rust is installed..." + if [ "${USE_CONDA}" -eq 0 ]; then + install_rust + fi +} + test_cpp() { show_header "Build, install and test C++ libraries" # Build and test C++ maybe_setup_go + # XXX: pin Python for now since various other packages haven't caught up maybe_setup_conda \ --file ci/conda_env_cpp.txt \ compilers \ - go=1.21 || exit 1 + go=1.22 python=3.12 || exit 1 if [ "${USE_CONDA}" -gt 0 ]; then export CMAKE_PREFIX_PATH="${CONDA_BACKUP_CMAKE_PREFIX_PATH}:${CMAKE_PREFIX_PATH}" # The CMake setup forces RPATH to be the Conda prefix - local -r install_prefix="${CONDA_PREFIX}" + export CPP_INSTALL_PREFIX="${CONDA_PREFIX}" else - local -r install_prefix="${ARROW_TMPDIR}/local" + export CPP_INSTALL_PREFIX="${ARROW_TMPDIR}/local" fi export CMAKE_BUILD_PARALLEL_LEVEL=${CMAKE_BUILD_PARALLEL_LEVEL:-${NPROC}} @@ -428,14 +533,14 @@ test_cpp() { export ADBC_CMAKE_ARGS="-DADBC_INSTALL_NAME_RPATH=OFF" export ADBC_USE_ASAN=OFF export ADBC_USE_UBSAN=OFF - "${ADBC_DIR}/ci/scripts/cpp_build.sh" "${ADBC_SOURCE_DIR}" "${ARROW_TMPDIR}/cpp-build" "${install_prefix}" + "${ADBC_DIR}/ci/scripts/cpp_build.sh" "${ADBC_SOURCE_DIR}" "${ARROW_TMPDIR}/cpp-build" "${CPP_INSTALL_PREFIX}" # FlightSQL driver requires running database for testing export BUILD_DRIVER_FLIGHTSQL=0 # PostgreSQL driver requires running database for testing export BUILD_DRIVER_POSTGRESQL=0 # Snowflake driver requires snowflake creds for testing export BUILD_DRIVER_SNOWFLAKE=0 - "${ADBC_DIR}/ci/scripts/cpp_test.sh" "${ARROW_TMPDIR}/cpp-build" "${install_prefix}" + "${ADBC_DIR}/ci/scripts/cpp_test.sh" "${ARROW_TMPDIR}/cpp-build" "${CPP_INSTALL_PREFIX}" export BUILD_DRIVER_FLIGHTSQL=1 export BUILD_DRIVER_POSTGRESQL=1 export BUILD_DRIVER_SNOWFLAKE=1 @@ -456,7 +561,8 @@ test_python() { # Build and test Python maybe_setup_virtualenv cython duckdb pandas protobuf pyarrow pytest setuptools_scm setuptools importlib_resources || exit 1 - maybe_setup_conda --file "${ADBC_DIR}/ci/conda_env_python.txt" || exit 1 + # XXX: pin Python for now since various other packages haven't caught up + maybe_setup_conda --file "${ADBC_DIR}/ci/conda_env_python.txt" python=3.12 || exit 1 if [ "${USE_CONDA}" -gt 0 ]; then CMAKE_PREFIX_PATH="${CONDA_BACKUP_CMAKE_PREFIX_PATH}:${CMAKE_PREFIX_PATH}" @@ -538,9 +644,11 @@ test_glib() { test_csharp() { show_header "Build and test C# libraries" - 
install_csharp + maybe_setup_dotnet + maybe_setup_conda dotnet || exit 1 - echo "C♯ is not implemented" + "${ADBC_DIR}/ci/scripts/csharp_build.sh" "${ADBC_SOURCE_DIR}" + "${ADBC_DIR}/ci/scripts/csharp_test.sh" "${ADBC_SOURCE_DIR}" } test_js() { @@ -563,7 +671,7 @@ test_go() { # apache/arrow-adbc#517: `go build` calls git. Don't assume system # has git; even if it's there, go_build.sh sets DYLD_LIBRARY_PATH # which can interfere with system git. - maybe_setup_conda compilers git go=1.21 || exit 1 + maybe_setup_conda compilers git go=1.22 || exit 1 if [ "${USE_CONDA}" -gt 0 ]; then # The CMake setup forces RPATH to be the Conda prefix @@ -577,6 +685,17 @@ test_go() { "${ADBC_DIR}/ci/scripts/go_test.sh" "${ADBC_SOURCE_DIR}" "${ARROW_TMPDIR}/go-build" "${install_prefix}" } +test_rust() { + show_header "Build and test Rust libraries" + + maybe_setup_rust || exit 1 + maybe_setup_conda rust || exit 1 + + # We expect the C++ libraries to exist. + "${ADBC_DIR}/ci/scripts/rust_build.sh" "${ADBC_SOURCE_DIR}" + "${ADBC_DIR}/ci/scripts/rust_test.sh" "${ADBC_SOURCE_DIR}" "${CPP_INSTALL_PREFIX}" +} + ensure_source_directory() { show_header "Ensuring source directory" @@ -644,6 +763,9 @@ test_source_distribution() { if [ ${TEST_CPP} -gt 0 ]; then test_cpp fi + if [ ${TEST_CSHARP} -gt 0 ]; then + test_csharp + fi if [ ${TEST_GLIB} -gt 0 ]; then test_glib fi @@ -659,6 +781,9 @@ test_source_distribution() { if [ ${TEST_R} -gt 0 ]; then test_r fi + if [ ${TEST_RUST} -gt 0 ]; then + test_rust + fi popd } @@ -702,7 +827,7 @@ test_linux_wheels() { local arch="x86_64" fi - local python_versions="${TEST_PYTHON_VERSIONS:-3.9 3.10 3.11}" + local python_versions="${TEST_PYTHON_VERSIONS:-3.9 3.10 3.11 3.12}" for python in ${python_versions}; do local pyver=${python/m} @@ -724,7 +849,7 @@ test_macos_wheels() { local platform_tags="x86_64" fi - local python_versions="${TEST_PYTHON_VERSIONS:-3.9 3.10 3.11}" + local python_versions="${TEST_PYTHON_VERSIONS:-3.9 3.10 3.11 3.12}" # verify arch-native wheels inside an arch-native conda environment for python in ${python_versions}; do @@ -811,9 +936,10 @@ test_jars() { : ${TEST_JS:=${TEST_SOURCE}} : ${TEST_GO:=${TEST_SOURCE}} : ${TEST_R:=${TEST_SOURCE}} +: ${TEST_RUST:=${TEST_SOURCE}} # Automatically test if its activated by a dependent -TEST_CPP=$((${TEST_CPP} + ${TEST_GO} + ${TEST_GLIB} + ${TEST_PYTHON})) +TEST_CPP=$((${TEST_CPP} + ${TEST_GO} + ${TEST_GLIB} + ${TEST_PYTHON} + ${TEST_RUST})) # Execute tests in a conda enviroment : ${USE_CONDA:=0} diff --git a/dev/release/versions.env b/dev/release/versions.env index cbc6998e3f..8021f118e2 100644 --- a/dev/release/versions.env +++ b/dev/release/versions.env @@ -17,18 +17,18 @@ # The release as a whole has a counter-based identifier (as in, 12 is the # 12th release of ADBC). This is used to identify tags, branches, and so on. -RELEASE="13" -PREVIOUS_RELEASE="12" +RELEASE="16" +PREVIOUS_RELEASE="15" # Individual components will have a SemVer. -VERSION_CSHARP="0.13.0" -VERSION_JAVA="0.13.0" +VERSION_CSHARP="0.16.0" +VERSION_JAVA="0.16.0" # Because C++/GLib/Go/Python/Ruby are effectively tied at the hip, they share # a single version number. Also covers Conda/Linux packages. 
-VERSION_NATIVE="1.1.0" -VERSION_R="0.13.0" -VERSION_RUST="0.13.0" +VERSION_NATIVE="1.4.0" +VERSION_R="0.16.0" +VERSION_RUST="0.16.0" # Required by the version bump script -PREVIOUS_VERSION_NATIVE="1.0.0" -PREVIOUS_VERSION_R="0.12.0" +PREVIOUS_VERSION_NATIVE="1.3.0" +PREVIOUS_VERSION_R="0.15.0" diff --git a/docker-compose.yml b/docker-compose.yml index 1c6192f276..bf961cdb84 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -30,10 +30,23 @@ services: command: | /adbc/ci/scripts/csharp_pack.sh /adbc + ################################ C/C++ ####################################### + + # Build/test with latest Clang + cpp-clang-latest: + build: + context: . + dockerfile: ci/docker/cpp-clang-latest.dockerfile + args: + VCPKG: ${VCPKG} + volumes: + - .:/adbc:delegated + command: "bash -c 'export PATH=$PATH:/opt/go/bin CC=$(which clang) CXX=$(which clang++) && git config --global --add safe.directory /adbc && /adbc/ci/scripts/cpp_build.sh /adbc /adbc/build && env BUILD_ALL=0 BUILD_DRIVER_MANAGER=1 BUILD_DRIVER_SQLITE=1 /adbc/ci/scripts/cpp_test.sh /adbc/build'" + ############################ Documentation ################################### docs: - image: condaforge/mambaforge:latest + image: condaforge/miniforge3:latest volumes: - .:/adbc:delegated environment: @@ -66,6 +79,20 @@ services: - .:/adbc:delegated command: "/bin/bash -c '/adbc/ci/scripts/python_conda_test.sh /adbc /adbc/build'" + python-debug: + image: ${REPO}:${ARCH}-python-${PYTHON}-debug-adbc + build: + context: . + cache_from: + - ${REPO}:${ARCH}-python-${PYTHON}-debug-adbc + dockerfile: ci/docker/python-debug.dockerfile + args: + ARCH: ${ARCH} + GO: ${GO} + volumes: + - .:/adbc:delegated + command: /adbc/ci/docker/python-debug.sh + ############################ Python sdist ################################## python-sdist: @@ -94,7 +121,10 @@ services: ############################ Python wheels ################################## - python-wheel-manylinux: + # We build on a different image to use an older base image/glibc, then + # relocate on a separate image so that we can use a newer docker for cibuildwheel + + python-wheel-manylinux-build: image: ${REPO}:${ARCH}-python-${PYTHON}-wheel-manylinux-${MANYLINUX}-vcpkg-${VCPKG}-adbc build: context: . @@ -108,14 +138,28 @@ services: PYTHON: ${PYTHON} REPO: ${REPO} VCPKG: ${VCPKG} + volumes: + - .:/adbc + # Must set safe.directory so go/miniver won't error when calling git + command: "'git config --global --add safe.directory /adbc && /adbc/ci/scripts/python_wheel_unix_build.sh ${ARCH} /adbc /adbc/build'" + + python-wheel-manylinux-relocate: + image: ${REPO}:adbc-python-${PYTHON}-wheel-relocate + platform: ${PLATFORM} + build: + context: . 
+ cache_from: + - ${REPO}:adbc-python-${PYTHON}-wheel-relocate + dockerfile: ci/docker/python-wheel-manylinux-relocate.dockerfile volumes: - /var/run/docker.sock:/var/run/docker.sock - .:/adbc - # Must set safe.directory so miniver won't error when calling git - command: "'git config --global --add safe.directory /adbc && git config --global --get safe.directory && /adbc/ci/scripts/python_wheel_unix_build.sh ${ARCH} /adbc /adbc/build'" + # Must set safe.directory so go/miniver won't error when calling git + command: "bash -c 'git config --global --add safe.directory /adbc && python -m venv /venv && source /venv/bin/activate && /adbc/ci/scripts/python_wheel_unix_relocate.sh ${ARCH} /adbc /adbc/build'" python-wheel-manylinux-test: image: ${ARCH}/python:${PYTHON}-slim + platform: ${PLATFORM} volumes: - .:/adbc:delegated command: /adbc/ci/scripts/python_wheel_unix_test.sh /adbc diff --git a/docs/source/AdbcQuadrants.mmd b/docs/source/AdbcQuadrants.mmd new file mode 100644 index 0000000000..78c4d636c8 --- /dev/null +++ b/docs/source/AdbcQuadrants.mmd @@ -0,0 +1,27 @@ +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. 
+ +quadrantChart + x-axis Database-specific --> Database-agnostic + y-axis Row-oriented --> Arrow-native + + ADBC: [0.9, 0.9] + JDBC: [0.9, 0.1] + ODBC: [0.9, 0.2] + "Flight SQL": [0.6, 0.9] + "BigQuery API": [0.1, 0.8] + "PostgreSQL protocol": [0.4, 0.1] diff --git a/docs/source/AdbcQuadrants.mmd.svg b/docs/source/AdbcQuadrants.mmd.svg new file mode 100644 index 0000000000..4fcc93d847 --- /dev/null +++ b/docs/source/AdbcQuadrants.mmd.svg @@ -0,0 +1,19 @@ + +PostgreSQL protocolBigQuery APIFlight SQLODBCJDBCADBCDatabase-specificDatabase-agnosticRow-orientedArrow-native diff --git a/docs/source/conf.py b/docs/source/conf.py index 0b0974f835..86bb6371e2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,7 +30,7 @@ registered trademarks or trademarks of The Apache Software Foundation in the United States and other countries.""" author = "the Apache Arrow Developers" -release = "13 (dev)" +release = "16 (dev)" # Needed to generate version switcher version = release @@ -43,7 +43,6 @@ "adbc_cookbook", # generic directives to enable intersphinx for java "adbc_java_domain", - "breathe", "numpydoc", "sphinx.ext.autodoc", "sphinx.ext.doctest", @@ -78,12 +77,15 @@ "show-inheritance": True, } -# -- Options for Breathe ----------------------------------------------------- - -breathe_default_project = "adbc" -breathe_projects = { - "adbc": "../../c/apidoc/xml/", -} +# https://stackoverflow.com/questions/11417221/sphinx-autodoc-gives-warning-pyclass-reference-target-not-found-type-warning +nitpick_ignore = [ + ("py:class", "abc.ABC"), + ("py:class", "datetime.date"), + ("py:class", "datetime.datetime"), + ("py:class", "datetime.time"), + ("py:class", "enum.Enum"), + ("py:class", "enum.IntEnum"), +] # -- Options for doctest ----------------------------------------------------- @@ -116,6 +118,8 @@ intersphinx_mapping = { "arrow": ("https://arrow.apache.org/docs/", None), + "pandas": ("https://pandas.pydata.org/docs/", None), + "polars": ("https://docs.pola.rs/api/python/stable/", None), } # Add env vars like ADBC_INTERSPHINX_MAPPING_adbc_java = url;path @@ -130,10 +134,6 @@ def _find_intersphinx_mappings(): url, _, path = val.partition(";") print("[ADBC] Found Intersphinx mapping", name) intersphinx_mapping[name] = (url, path) - # "adbc_java": ( - # "http://localhost:8000/", - # "/home/lidavidm/Code/arrow-adbc/java/target/site/apidocs/objects.inv", - # ), _find_intersphinx_mappings() diff --git a/docs/source/cpp/api/index.rst b/docs/source/cpp/api/index.rst index 0011d51949..1b5bbe9104 100644 --- a/docs/source/cpp/api/index.rst +++ b/docs/source/cpp/api/index.rst @@ -19,8 +19,5 @@ C/C++ API Reference =================== -.. toctree:: - :maxdepth: 2 - - adbc - adbc_driver_manager +This is a stub page for the Doxygen documentation. If you're seeing this page, +it means that the actual documentation was not generated. diff --git a/docs/source/cpp/concurrency.rst b/docs/source/cpp/concurrency.rst index d7cfbae318..1920b1a8b6 100644 --- a/docs/source/cpp/concurrency.rst +++ b/docs/source/cpp/concurrency.rst @@ -47,7 +47,7 @@ AdbcConnection: /* What happens to the result set of stmt1? */ What happens if the client application calls -:cpp:func:`AdbcStatementExecuteQuery` on ``stmt1``, then on ``stmt2``, +:c:func:`AdbcStatementExecuteQuery` on ``stmt1``, then on ``stmt2``, without reading the result set of ``stmt1``? Some existing client libraries/protocols, like libpq, don't support concurrent execution of queries from a single connection. 
So the driver would have to diff --git a/docs/source/cpp/api/adbc_driver_manager.rst b/docs/source/cpp/driver_example.rst similarity index 51% rename from docs/source/cpp/api/adbc_driver_manager.rst rename to docs/source/cpp/driver_example.rst index f624ce993c..0776b5e8ef 100644 --- a/docs/source/cpp/api/adbc_driver_manager.rst +++ b/docs/source/cpp/driver_example.rst @@ -15,8 +15,35 @@ .. specific language governing permissions and limitations .. under the License. -========================= -``adbc_driver_manager.h`` -========================= +============== +Driver Example +============== -.. doxygenfile:: adbc_driver_manager.h +.. recipe:: recipe_driver/driver_example.cc + :language: cpp + +Low-level testing +================= + +.. recipe:: recipe_driver/driver_example_test.cc + :language: cpp + +High-level testing +================== + +.. recipe:: recipe_driver/driver_example.py + +High-level tests can also be written in R using the ``adbcdrivermanager`` +package. + +.. code-block:: r + + library(adbcdrivermanager) + + drv <- adbc_driver("build/libdriver_example.dylib") + db <- adbc_database_init(drv, uri = paste0("file://", getwd())) + con <- adbc_connection_init(db) + + data.frame(col = 1:3) |> write_adbc(con, "example.arrows") + con |> read_adbc("SELECT * FROM example.arrows") |> as.data.frame() + unlink("example.arrows") diff --git a/docs/source/cpp/driver_manager.rst b/docs/source/cpp/driver_manager.rst index d8db791d1f..ed02ec532d 100644 --- a/docs/source/cpp/driver_manager.rst +++ b/docs/source/cpp/driver_manager.rst @@ -94,13 +94,12 @@ Then they can be used via CMake, e.g.: Usage ===== -To create a database, use the :cpp:class:`AdbcDatabase` API as usual, -but during initialization, provide two additional parameters in -addition to the driver-specific connection parameters: ``driver`` and -(optionally) ``entrypoint``. ``driver`` must be the name of a library -to load, or the path to a library to load. ``entrypoint``, if -provided, should be the name of the symbol that serves as the ADBC -entrypoint (see :cpp:type:`AdbcDriverInitFunc`). +To create a database, use the :c:struct:`AdbcDatabase` API as usual, but +during initialization, provide two additional parameters in addition to the +driver-specific connection parameters: ``driver`` and (optionally) +``entrypoint``. ``driver`` must be the name of a library to load, or the path +to a library to load. ``entrypoint``, if provided, should be the name of the +symbol that serves as the ADBC entrypoint (see :c:type:`AdbcDriverInitFunc`). .. code-block:: c @@ -120,5 +119,5 @@ entrypoint (see :cpp:type:`AdbcDriverInitFunc`). API Reference ============= -The driver manager includes a few additional functions beyond the ADBC -API. See the API reference: :doc:`./api/adbc_driver_manager`. +The driver manager includes a few additional functions beyond the ADBC API. +See the API reference: :external+cpp_adbc:doc:`adbc_driver_manager.h`. diff --git a/docs/source/cpp/index.rst b/docs/source/cpp/index.rst index add0e29efe..29bbfc9202 100644 --- a/docs/source/cpp/index.rst +++ b/docs/source/cpp/index.rst @@ -25,4 +25,5 @@ C and C++ quickstart driver_manager concurrency + driver_example api/index diff --git a/docs/source/cpp/recipe/quickstart.cc b/docs/source/cpp/recipe/quickstart.cc index c1e68565ef..b9ae384f01 100644 --- a/docs/source/cpp/recipe/quickstart.cc +++ b/docs/source/cpp/recipe/quickstart.cc @@ -62,7 +62,7 @@ #include #include -#include +#include #include /// Then we'll add some (very basic) error checking helpers. 
@@ -139,7 +139,7 @@ int main() { /// ----------------- /// /// We execute a query by setting the query on the statement, then - /// calling :cpp:func:`AdbcStatementExecuteQuery`. The results come + /// calling :c:func:`AdbcStatementExecuteQuery`. The results come /// back through the `Arrow C Data Interface`_. /// /// .. _Arrow C Data Interface: https://arrow.apache.org/docs/format/CDataInterface.html diff --git a/docs/source/cpp/recipe_driver/CMakeLists.txt b/docs/source/cpp/recipe_driver/CMakeLists.txt new file mode 100644 index 0000000000..8e1159a855 --- /dev/null +++ b/docs/source/cpp/recipe_driver/CMakeLists.txt @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.18) + +project(adbc_cookbook_recipes_driver + VERSION "1.0.0" + LANGUAGES CXX) + +include(CTest) +include(FetchContent) + +set(CMAKE_CXX_STANDARD 17) + +set(NANOARROW_IPC ON) +set(NANOARROW_NAMESPACE "DriverExamplePrivate") +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +fetchcontent_declare(nanoarrow + URL "https://www.apache.org/dyn/closer.lua?action=download&filename=arrow/apache-arrow-nanoarrow-0.6.0/apache-arrow-nanoarrow-0.6.0.tar.gz" + URL_HASH SHA256=e4a02ac51002ad1875bf09317e70adb959005fad52b240ff59f73b970fa485d1 +) +fetchcontent_makeavailable(nanoarrow) + +# TODO: We could allow this to be installed + linked to as a target; however, +# fetchcontent is a little nicer for this kind of thing (statically linked +# pinned version of something that doesn't rely on a system library). 
+add_library(adbc_driver_framework ../../../../c/driver/framework/utility.cc + ../../../../c/driver/framework/objects.cc) +target_include_directories(adbc_driver_framework PRIVATE ../../../../c + ../../../../c/include) +target_link_libraries(adbc_driver_framework PRIVATE nanoarrow::nanoarrow) + +add_library(driver_example SHARED driver_example.cc) +target_include_directories(driver_example PRIVATE ../../../../c ../../../../c/include) +target_link_libraries(driver_example PRIVATE adbc_driver_framework + nanoarrow::nanoarrow_ipc) + +if(ADBC_DRIVER_EXAMPLE_BUILD_TESTS) + fetchcontent_declare(googletest + URL https://github.com/google/googletest/archive/refs/tags/v1.15.1.tar.gz + URL_HASH SHA256=5052e088b16bdd8c6f0c7f9cafc942fd4f7c174f1dac6b15a8dd83940ed35195 + ) + fetchcontent_makeavailable(googletest) + + find_package(AdbcDriverManager REQUIRED) + + add_executable(driver_example_test driver_example_test.cc) + target_link_libraries(driver_example_test + PRIVATE gtest_main driver_example + AdbcDriverManager::adbc_driver_manager_shared) + + include(GoogleTest) + gtest_discover_tests(driver_example_test) + +endif() diff --git a/docs/source/cpp/recipe_driver/driver_example.cc b/docs/source/cpp/recipe_driver/driver_example.cc new file mode 100644 index 0000000000..9c7ae75f8f --- /dev/null +++ b/docs/source/cpp/recipe_driver/driver_example.cc @@ -0,0 +1,304 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// RECIPE STARTS HERE + +/// Here we'll show the structure of building an ADBC driver in C++ using +/// the ADBC driver framework library. This is the same library that ADBC +/// uses to build its SQLite and PostgreSQL drivers and abstracts away +/// the details of C callables and catalog/metadata functions that can be +/// difficult to implement but are essential for efficiently leveraging +/// the rest of the ADBC ecosystem. +/// +/// At a high level, we'll be building a driver whose "database" is a directory +/// where each "table" in the database is a file containing an Arrow IPC stream. +/// Tables can be written using the bulk ingest feature and tables can be read +/// with a simple query in the form ``SELECT * FROM (the file)``. +/// +/// Installation +/// ============ +/// +/// This quickstart is actually a literate C++ file. You can clone +/// the repository, build the sample, and follow along. +/// +/// We'll assume you're using conda-forge_ for dependencies. CMake, a +/// C++17 compiler, and the ADBC libraries are required. They can be +/// installed as follows: +/// +/// .. code-block:: shell +/// +/// mamba install cmake compilers libadbc-driver-manager +/// +/// .. _conda-forge: https://conda-forge.org/ +/// +/// Building +/// ======== +/// +/// We'll use CMake_ here. From a source checkout of the ADBC repository: +/// +/// .. 
code-block:: shell +/// +/// mkdir build +/// cd build +/// cmake ../docs/source/cpp/recipe_driver -DADBC_DRIVER_EXAMPLE_BUILD_TESTS=ON +/// cmake --build . +/// ctest +/// +/// .. _CMake: https://cmake.org/ +/// +/// Building an ADBC Driver using C++ +/// ================================= +/// +/// Let's start with some includes. Notably, we'll need the driver framework +/// header files and nanoarrow_, which we'll use to create and consume the +/// Arrow C data interface structures in this example driver. + +/// .. _nanoarrow: https://arrow.apache.org/nanoarrow + +#include "driver_example.h" + +#include +#include + +#include "driver/framework/connection.h" +#include "driver/framework/database.h" +#include "driver/framework/statement.h" + +#include "nanoarrow/nanoarrow.hpp" +#include "nanoarrow/nanoarrow_ipc.hpp" + +#include "arrow-adbc/adbc.h" + +/// Next, we'll bring a few essential framework types into the namespace +/// to reduce the verbosity of the implementation: +/// +/// * :cpp:class:`adbc::driver::Option` : Options can be set on an ADBC database, +/// connection, and statmenent. They can be strings, opaque binary, doubles, or +/// integers. The ``Option`` class abstracts the details of how to get, set, +/// and parse these values. +/// * :cpp:class:`adbc::driver::Status`: The ``Status`` is the ADBC driver +/// framework's error handling mechanism: functions with no return value that +/// can fail return a ``Status``. You can use ``UNWRAP_STATUS(some_call())`` as +/// shorthand for ``Status status = some_call(); if (!status.ok()) return +/// status;`` to succinctly propagate errors. +/// * :cpp:class:`adbc::driver::Result`: The ``Result`` is used as a return +/// value for functions that on success return a value of type ``T`` and on +/// failure communicate their error using a ``Status``. You can use +/// ``UNWRAP_RESULT(some_type value, some_call())`` as shorthand for +/// +/// .. code-block:: cpp +/// +/// some_type value; +/// Result maybe_value = some_call(); +/// if (!maybe_value.status().ok()) { +/// return maybe_value.status(); +/// } else { +/// value = *maybe_value; +/// } + +using adbc::driver::Option; +using adbc::driver::Result; +using adbc::driver::Status; + +namespace { + +/// Next, we'll provide the database implementation. The driver framework uses +/// the Curiously Recurring Template Pattern (CRTP_). The details of this are +/// handled by the framework, but functionally this is still just overriding +/// methods from a base class that handles the details. +/// +/// Here, our database implementation will simply record the ``uri`` passed +/// by the user. Our interpretation of this will be a ``file://`` uri to +/// a directory to which our IPC files should be written and/or IPC files +/// should be read. This is the role of the database in ADBC: a shared +/// handle to a database that potentially caches some shared state among +/// connections, but which still allows multiple connections to execute +/// against the database concurrently. +/// +/// .. 
_CRTP: https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern + +class DriverExampleDatabase : public adbc::driver::Database<DriverExampleDatabase> { + public: + [[maybe_unused]] constexpr static std::string_view kErrorPrefix = "[example]"; + + Status SetOptionImpl(std::string_view key, Option value) override { + // Handle and validate options implemented by this driver + if (key == "uri") { + UNWRAP_RESULT(std::string_view uri, value.AsString()); + + if (uri.find("file://") != 0) { + return adbc::driver::status::InvalidArgument( + "[example] uri must start with 'file://'"); + } + + uri_ = uri; + return adbc::driver::status::Ok(); + } + + // Defer to the base implementation to handle state managed by the base + // class (and error for all other options). + return Base::SetOptionImpl(key, value); + } + + Result
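/// As a rough usage sketch (not part of the recipe code above, and assuming the
/// driver manager is installed and this example has been built as
/// ``build/libdriver_example.so``; the library path and directory are placeholders),
/// an application could load the driver and point it at a directory through the
/// plain ADBC C API, with error handling omitted for brevity:
///
/// .. code-block:: c
///
///    #include <arrow-adbc/adbc.h>
///
///    struct AdbcError error = {0};
///    struct AdbcDatabase database = {0};
///    AdbcDatabaseNew(&database, &error);
///    /* "driver" is interpreted by the driver manager */
///    AdbcDatabaseSetOption(&database, "driver", "build/libdriver_example.so", &error);
///    /* "uri" is the option handled by the example database implementation */
///    AdbcDatabaseSetOption(&database, "uri", "file:///tmp/example_db", &error);
///    AdbcDatabaseInit(&database, &error);
///    /* ... create connections and statements here ... */
///    AdbcDatabaseRelease(&database, &error);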